[DL]SRCNN:Super-resolution with CNN
import math
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image, ImageFilter
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.transforms import ToTensor

# Usage:
#   python face_cnn.py train   -> train the three per-channel networks
#   python face_cnn.py run     -> restore every image found in ./images
#   python face_cnn.py test    -> evaluate on the test set (also the default)

means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

pic_height = 512
pic_width = 512

# ITU-R BT.601 "studio swing" RGB -> YCbCr matrix and offset.
_RGB_TO_YCBCR = np.array([[0.257, 0.504, 0.098],
                          [-0.148, -0.291, 0.439],
                          [0.439, -0.368, -0.071]])
_YCBCR_TO_RGB = np.linalg.inv(_RGB_TO_YCBCR)
_YCBCR_SHIFT = np.array([16, 128, 128])

# One checkpoint per colour channel (R, G, B).
MODEL_PATHS = ("./model/srcnn_l1.m", "./model/srcnn_l2.m", "./model/srcnn_l3.m")

TRAIN_DIR = '/home/nicholas/Documents/dataset/imageNet'
TEST_DIR = './dataset/Set14/image_SRF_3'


def psnr(img1, img2, max_value=255.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        img1, img2: tensors (or array-likes) of identical shape.
        max_value: largest value a pixel can take (255 for 8-bit data,
            1.0 for tensors produced by ``ToTensor``).

    Returns:
        ``10 * log10(max_value**2 / MSE)``, or 100 when the images are
        (numerically) identical.
    """
    # detach(): network outputs carry autograd history and cannot be reduced
    # to a Python float otherwise; double(): avoids uint8 wrap-around.
    a = torch.as_tensor(img1).detach().double()
    b = torch.as_tensor(img2).detach().double()
    mse = torch.mean(((a - b) / max_value) ** 2).item()
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(1.0 / mse)


def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB array into YCbCr (BT.601).

    Input that is not a 3-channel, 3-dimensional array is returned unchanged.
    The result is a float array of the same shape.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    rgb = rgb_image.astype(np.float32)
    # Per pixel: M @ rgb + shift, expressed as one matrix product.
    return rgb @ _RGB_TO_YCBCR.T + _YCBCR_SHIFT


def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) YCbCr array back into a uint8 RGB array.

    Input that is not a 3-channel, 3-dimensional array is returned unchanged.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    ycbcr = ycbcr_image.astype(np.float32)
    rgb = (ycbcr - _YCBCR_SHIFT) @ _YCBCR_TO_RGB.T
    # Round and clip before the cast so out-of-gamut values saturate instead
    # of wrapping around in uint8.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)


class dataloader:
    """Iterate over (blurred, original) image batches read from a directory.

    Args:
        path: directory containing the image files.
        batchs: number of batches per pass; a value <= 0 iterates forever.
        batch_size: number of image pairs per batch.
        test: when True use the files after the split point, otherwise the
            files before it.
        test_offset: fraction of the file list assigned to the training split.
        order: when True the file list is sorted, otherwise it is shuffled.
    """

    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path

        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            # NOTE: every instance shuffles on its own, so a train and a test
            # loader over the same directory are not guaranteed to be disjoint
            # unless order=True.
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # [file_start, file_stop) is a half-open index range into file_list.
        split = int(np.floor(self.file_number * self.test_offset))
        if self.test:
            self.file_start, self.file_stop = split, self.file_number
        else:
            self.file_start, self.file_stop = 0, split
        self.file_idx = self.file_start

        # Currently unused by _next(); kept for callers that reference it.
        self.downsample = transforms.Compose([
            transforms.Resize((int(pic_height / 2), int(pic_width / 2)), Image.BILINEAR),
        ])

        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
            transforms.ToTensor(),
        ])

    def get_len(self):
        """Return the number of files in the active split."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(X, Y)`` batch, each stacked along dimension 0."""
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        pairs = [self._next() for _ in range(self.batch_size)]
        X = torch.stack([pair[0] for pair in pairs], 0)
        Y = torch.stack([pair[1] for pair in pairs], 0)
        return X, Y

    def _next(self):
        """Load one file and return its (blurred input, target) tensors."""
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start  # wrap around within the split
        if self.file_idx >= self.file_number:
            raise IndexError("no image files available in {!r}".format(self.path))

        file_path = self.path + '/' + self.file_list[self.file_idx]
        self.file_idx += 1

        # Always normalise to 3-channel RGB so every sample stacks into one
        # batch; convert() loads the pixels, so the file can be closed.
        with Image.open(file_path) as source:
            Y = source.convert("RGB")

        X = Y.filter(ImageFilter.BLUR)
        # NOTE(review): the random crop is applied separately to X and Y; if
        # the two calls ever pick different crops the pair is misaligned --
        # confirm against the torchvision version in use.
        X = self.transform(X)
        Y = self.transform(Y)
        return X, Y


class SRCNN(nn.Module):
    """Three-layer single-channel SRCNN (9-1-5 kernels, 64/32 feature maps)."""

    def __init__(self):
        super().__init__()
        self.patch_extraction = nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4,
                                          padding_mode='replicate')
        self.non_linear = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)
        self.reconstruction = nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2,
                                        padding_mode='replicate')

    def init_weights(self):
        """Initialise all weights from N(0, 0.001) and all biases to zero."""
        print(" ****************************")
        for layer in (self.patch_extraction, self.non_linear, self.reconstruction):
            nn.init.normal_(layer.weight, mean=0.0, std=0.001)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map a (N, 1, H, W) input to a (N, 1, H, W) output (no final activation)."""
        fm_1 = F.relu(self.patch_extraction(x))
        fm_2 = F.relu(self.non_linear(fm_1))
        return self.reconstruction(fm_2)


tensor_to_pil = ToPILImage()


def _make_optimizer(model):
    """Build the Adam optimiser with the per-layer learning rates for ``model``."""
    return torch.optim.Adam([
        {"params": model.patch_extraction.parameters(), "lr": 0.0001},
        {"params": model.non_linear.parameters(), "lr": 0.0001},
        {"params": model.reconstruction.parameters(), "lr": 0.00001},
    ])


model1 = SRCNN()
model2 = SRCNN()
model3 = SRCNN()

loss = nn.MSELoss()

optimizer1 = _make_optimizer(model1)
optimizer2 = _make_optimizer(model2)
optimizer3 = _make_optimizer(model3)

# Created on first use so that importing the module or using the "run" mode
# does not require the dataset directories to exist.
dataloader_train = None
dataloader_test = None

new_training = 0


def _get_train_loader():
    """Return the training loader, creating it on first use."""
    global dataloader_train
    if dataloader_train is None:
        dataloader_train = dataloader(TRAIN_DIR, 200, 2)
    return dataloader_train


def _get_test_loader():
    """Return the test loader, creating it on first use."""
    global dataloader_test
    if dataloader_test is None:
        dataloader_test = dataloader(TEST_DIR, 1, 2)
    return dataloader_test


def _load_model(path):
    """Load a whole-module checkpoint written by ``torch.save(model, path)``.

    SECURITY: the checkpoint is a pickle and unpickling executes arbitrary
    code -- only load files you created yourself.
    """
    try:
        # Recent PyTorch defaults to weights_only=True, which rejects
        # pickled modules.
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only argument.
        return torch.load(path)


def _predict(models, X):
    """Apply one model per colour channel and return a (N, C, H, W) tensor."""
    return torch.cat([m(X[:, c:c + 1]) for c, m in enumerate(models)], dim=1)


def _show_pair(restored, degraded, pause):
    """Show the restored image (left) next to the network input (right)."""
    plt.subplot(1, 2, 1)
    # The network output is unbounded; clamp so the 8-bit conversion saturates.
    plt.imshow(tensor_to_pil(restored.clamp(0, 1)))
    plt.subplot(1, 2, 2)
    plt.imshow(tensor_to_pil(degraded))
    plt.show()
    plt.pause(pause)


def train():
    """Train the three per-channel networks and save them to MODEL_PATHS.

    With ``new_training == 1`` the weights are re-initialised, otherwise the
    existing checkpoints are loaded and training continues from them.
    """
    # Declared global: otherwise the assignments below make the names local
    # and the "new training" branch fails with UnboundLocalError.
    global model1, model2, model3, optimizer1, optimizer2, optimizer3

    if new_training == 1:
        print(" new training ...")
        for model in (model1, model2, model3):
            model.init_weights()
    else:
        print(" continue training ...")
        model1, model2, model3 = (_load_model(p) for p in MODEL_PATHS)
        # The optimisers must track the parameters of the loaded models,
        # not those of the freshly constructed module-level ones.
        optimizer1, optimizer2, optimizer3 = (
            _make_optimizer(m) for m in (model1, model2, model3))

    models = (model1, model2, model3)
    optimizers = (optimizer1, optimizer2, optimizer3)

    plt.ion()
    for i, (X, Y) in enumerate(_get_train_loader()):
        costs = []
        for c, (model, optimizer) in enumerate(zip(models, optimizers)):
            model.train()  # test() below switches the models to eval mode
            optimizer.zero_grad()
            cost = loss(model(X[:, c:c + 1]), Y[:, c:c + 1])
            cost.backward()
            optimizer.step()
            costs.append(cost.item())

        print("batch:{}, loss is {:.4f} {:.4f} {:.4f} ".format(i, *costs))

        if (i + 1) % 10 == 0:
            test(*models)
    plt.ioff()

    os.makedirs("./model", exist_ok=True)
    for model, path in zip(models, MODEL_PATHS):
        torch.save(model, path)


def test(model1=None, model2=None, model3=None):
    """Print PSNR on the test set and display output next to input.

    Args:
        model1, model2, model3: networks for the R, G and B channel. When
            any of them is omitted all three are loaded from MODEL_PATHS.
    """
    if model1 is None or model2 is None or model3 is None:
        models = tuple(_load_model(p) for p in MODEL_PATHS)
    else:
        models = (model1, model2, model3)
    for model in models:
        model.eval()

    with torch.no_grad():
        for X, Y in _get_test_loader():
            _, c, h, w = X.shape
            print(c, h, w)

            Y_ = _predict(models, X)
            # ToTensor() yields values in [0, 1], hence max_value=1.0.
            print(" Psnr = ", psnr(Y_, Y, max_value=1.0),
                  " psnr2 = ", psnr(X, Y, max_value=1.0))

            _show_pair(Y_[0], X[0], 0.1)


def run():
    """Restore every image in ./images with the saved models and display it."""
    models = tuple(_load_model(p) for p in MODEL_PATHS)
    for model in models:
        model.eval()

    path = "./images"
    transform = transforms.Compose([
        transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
        transforms.ToTensor(),
    ])

    with torch.no_grad():
        for img_name in os.listdir(path):
            with Image.open(os.path.join(path, img_name)) as source:
                img = source.convert("RGB")

            X = transform(img).unsqueeze(0)  # add the batch dimension
            Y_ = _predict(models, X)

            _show_pair(Y_[0], X[0], 1)


if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    # "and" short-circuits, so args[0] is never read when args is empty.
    if len(args) == 1 and args[0] == "train":
        train()
    elif len(args) == 1 and args[0] == "run":
        run()
    else:
        test()
标签: DL
日历
最新微语
- Watching the autumn leaves falling as you grow older together
2018-10-25 09:45
- 时间不可以倒流,但空间可以
2017-08-01 09:03
- 含羞草、电磁炮;汽车工业革命
2017-05-23 22:51
- 那个点子页面加几点:
去中心化的物联网通信协议
2017-05-09 22:13
- 有一种人怀疑阴阳的存在,另有一种人会怀疑1+1=2的正确性……
2017-03-01 17:08
分类
最新评论
- 萧
@Fluzak:The web host... - Fluzak
Nice blog here! Also... - Albertarive
In my opinion you co... - ChesterHep
What does it plan? - ChesterHep
No, opposite. - mojoheadz
Everything is OK!... - Josephmaigh
I just want to say t... - ChesterHep
What good topic - AnthonyBub
Certainly, never it ... - DavidNed
I think, that you ar...
发表评论: