[DL]DRRN
import os
import sys
import math
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable

means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
pic_height = 512
pic_width = 512

opt = {
    "pic_c": 32,            # number of channels in the middle convolution layers
    "pic_h": pic_height,    # height of the low-resolution input patches
    "pic_w": pic_width,
    "pic_up": 2,            # upscale factor
    "heads": 8,
    "vec_dim": 512,
    "N": 6,
    "x_vocab_len": 0,
    "y_vocab_len": 0,
    "sentence_len": 80,
    "batchs": 1000,
    "batch_size": 10,
    "pad": 0,
    "sof": 1,
    "eof": 2,
    "cat_padding": 5,
}

def verse_normalize(img, means, stds):
    # undo transforms.Normalize: x * std + mean, channel by channel
    means = torch.Tensor(means)
    stds = torch.Tensor(stds)
    for i in range(3):
        img[:, i] = img[:, i] * stds[i] + means[i]
    return img

""" # errors
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = ( (img1/1.0) - (img2/1.0) ) ** 2
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2):
    # assumes both tensors are already in the [0,1] range (ToTensor output),
    # so PSNR = 10 * log10(1 / MSE)
    img1 = img1.detach()
    img2 = img2.detach()
    mse = float(((img1 - img2) ** 2).mean())
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(1.0 / mse)

def rgb2ycbcr(rgb_image):
    """convert rgb into ycbcr (BT.601 coefficients, 0-255 input)"""
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
        #raise ValueError("input image is not a rgb image")
    rgb_image = rgb_image.astype(np.float32)
    transform_matrix = np.array([[ 0.257,  0.504,  0.098],
                                 [-0.148, -0.291,  0.439],
                                 [ 0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    ycbcr_image = np.zeros(shape=rgb_image.shape)
    w, h, _ = rgb_image.shape
    for i in range(w):
        for j in range(h):
            ycbcr_image[i, j, :] = np.dot(transform_matrix, rgb_image[i, j, :]) + shift_matrix
    return ycbcr_image

def ycbcr2rgb(ycbcr_image):
    """convert ycbcr into rgb"""
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
        #raise ValueError("input image is not a rgb image")
    ycbcr_image = ycbcr_image.astype(np.float32)
    transform_matrix = np.array([[ 0.257,  0.504,  0.098],
                                 [-0.148, -0.291,  0.439],
                                 [ 0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    rgb_image = np.zeros(shape=ycbcr_image.shape)
    w, h, _ = ycbcr_image.shape
    for i in range(w):
        for j in range(h):
            rgb_image[i, j, :] = np.dot(transform_matrix_inv, ycbcr_image[i, j, :]) - np.dot(transform_matrix_inv, shift_matrix)
    return rgb_image.astype(np.uint8)

class dataloader:
    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)
        # the first test_offset portion of the files is used for training,
        # the remainder for testing
        self.file_start = 0
        self.file_stop = self.file_number * self.test_offset - 1
        if self.test:
            self.file_start = self.file_number * self.test_offset
            self.file_stop = self.file_number - 1
        self.file_start = int(np.floor(self.file_start))
        self.file_stop = int(np.floor(self.file_stop))
        self.file_idx = self.file_start
        self.downsample = transforms.Compose([
            transforms.Resize((int(pic_height / 2), int(pic_width / 2)), Image.BILINEAR),
            #transforms.RandomResizedCrop( pic_height, scale=(0.08,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            #transforms.ToTensor()
            #transforms.Normalize(means, stds)
        ])
        self.transform = transforms.Compose([
            #transforms.Resize( (pic_height,pic_width), Image.BILINEAR),
            transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor()
            #transforms.Normalize(means, stds)
        ])

    def get_len(self):
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        # stop after `batchs` batches (a value of 0 means iterate forever)
        if (self.batchs_cnt >= self.batchs) and (self.batchs > 0):
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1
        X = []
        Y = []
        for i in range(self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)
        X = torch.stack(X, 0)
        Y = torch.stack(Y, 0)
        return X, Y

    def _next(self):
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start
        file_path = self.path + '/' + self.file_list[self.file_idx]
        self.file_idx += 1
        Y = Image.open(file_path)
        #print( "Y:", file_path, " (", Y.size )
        if len(Y.getbands()) == 1:
            Y = Y.convert("RGB")
        # the degraded input X is simply a blurred copy of the target Y
        #X = self.downsample(Y)
        X = Y.filter(ImageFilter.BLUR)
        X = self.transform(X)
        Y = self.transform(Y)
        return X, Y
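# A quick sanity check for the psnr() helper above (added example, not part of the
# original script; it only uses names already defined here). Identical tensors hit
# the 100 dB cap, while Gaussian noise with std 0.05 lands at roughly 26 dB.
def _psnr_demo():
    a = torch.rand(1, 3, 64, 64)                       # fake image in [0,1]
    b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)   # noisy copy
    print("psnr(a, a) =", psnr(a, a))                  # 100 (mse below threshold)
    print("psnr(a, b) =", psnr(a, b))                  # around 26 dB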
class DRRN(nn.Module):
    def __init__(self):
        super(DRRN, self).__init__()
        # 32 channels in the hidden layers (see opt["pic_c"])
        c = 32
        self.input = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=c, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        print(" ****************************")
        # He (Kaiming) initialization for every convolution layer
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def forward(self, x):
        residual = x
        inputs = self.input(self.relu(x))
        out = inputs
        # recursive residual block: the same conv1/conv2 pair is applied repeatedly,
        # each time adding back the block input (2 recursions here; the DRRN paper uses up to 25)
        for _ in range(2):
            out = self.conv2(self.relu(self.conv1(self.relu(out))))
            out = torch.add(out, inputs)
        out = self.output(self.relu(out))
        out = torch.add(out, residual)    # global residual connection
        return out
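# Shape check for the DRRN block above (added example, not part of the original
# script). The network is fully convolutional with one input and one output channel,
# so any single-channel patch keeps its spatial size, and the final
# torch.add(out, residual) means the convolutions only learn the correction
# to the blurry input.
def _drrn_shape_demo():
    net = DRRN()
    net.init_weights()
    x = torch.rand(1, 1, 128, 96)      # arbitrary single-channel patch
    with torch.no_grad():
        y = net(x)
    print(x.shape, "->", y.shape)      # both torch.Size([1, 1, 128, 96])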
tensor_to_pil = ToPILImage()

def display_img(img1, img2):
    plt.subplot(1, 2, 1)
    plt.imshow(img1)
    plt.subplot(1, 2, 2)
    plt.imshow(img2)
    plt.show()

'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
    plt.imshow(X)
    plt.pause(0.1)
plt.ioff()
iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()
'''

# one independent DRRN per colour channel, each with its own optimizer
model1 = DRRN()
model2 = DRRN()
model3 = DRRN()
loss = nn.MSELoss()
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001, betas=(0.9, 0.999))
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001, betas=(0.9, 0.999))
optimizer3 = torch.optim.Adam(model3.parameters(), lr=0.001, betas=(0.9, 0.999))

dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', 10, 2)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 2)

new_training = 1

def train():
    global model1
    global model2
    global model3
    if new_training == 1:
        print(" new training ...")
        model1.init_weights()
        model2.init_weights()
        model3.init_weights()
    else:
        print(" continue training ...")
        model1 = torch.load("./model/srcnn_l1.m")
        model2 = torch.load("./model/srcnn_l2.m")
        model3 = torch.load("./model/srcnn_l3.m")
    running_loss = 0
    plt.ion()
    i = 0
    for X, Y in dataloader_train:
        #print(X.shape)
        ''' Train '''
        model1.train()
        model2.train()
        model3.train()
        _, c, h, w = X.shape

        # channel 0 (R)
        Y_1 = model1(X[:, 0].reshape(-1, 1, h, w))
        optimizer1.zero_grad()
        cost1 = loss(Y_1, Y[:, 0].reshape(-1, 1, h, w))
        cost1.backward()
        optimizer1.step()

        # channel 1 (G)
        Y_2 = model2(X[:, 1].reshape(-1, 1, h, w))
        optimizer2.zero_grad()
        cost2 = loss(Y_2, Y[:, 1].reshape(-1, 1, h, w))
        cost2.backward()
        optimizer2.step()

        # channel 2 (B)
        Y_3 = model3(X[:, 2].reshape(-1, 1, h, w))
        optimizer3.zero_grad()
        cost3 = loss(Y_3, Y[:, 2].reshape(-1, 1, h, w))
        cost3.backward()
        optimizer3.step()

        #running_loss += cost.item()
        print("batch:{}, loss is {:.4f} {:.4f} {:.4f} ".format(i, cost1.item(), cost2.item(), cost3.item()))
        i += 1
        if i % 1 == 0:
            test(model1, model2, model3)

        ''' Display images '''
        """
        Y_ = torch.stack( [Y_1[:,0], Y_2[:,0], Y_3[:,0] ], dim = 1)
        iii = tensor_to_pil(Y_[0])
        plt.subplot(1, 2, 1)
        plt.imshow(iii)
        iii = tensor_to_pil(X[0])
        plt.subplot(1, 2, 2)
        plt.imshow(iii)
        plt.show()
        plt.pause(0.1)
        """
    plt.ioff()
    torch.save(model1, "./model/srcnn_l1.m")
    torch.save(model2, "./model/srcnn_l2.m")
    torch.save(model3, "./model/srcnn_l3.m")

def test(model1, model2, model3):
    running_loss = 0
    #plt.ion()
    for X, Y in dataloader_test:
        #print(X.shape)
        model1.eval()
        model2.eval()
        model3.eval()
        _, c, h, w = X.shape
        with torch.no_grad():
            Y_1 = model1(X[:, 0].reshape(-1, 1, h, w))
            Y_2 = model2(X[:, 1].reshape(-1, 1, h, w))
            Y_3 = model3(X[:, 2].reshape(-1, 1, h, w))
        print(c, h, w)
        # merge the three single-channel outputs back into an RGB image
        Y_ = torch.stack([Y_1[:, 0], Y_2[:, 0], Y_3[:, 0]], dim=1)
        psnrv = psnr(Y_, Y)
        print(" Psnr = ", psnrv, " psnr2 = ", psnr(X, Y))
        ''' Display images '''
        iii = tensor_to_pil(Y_[0])
        plt.subplot(1, 2, 1)
        plt.imshow(iii)
        iii = tensor_to_pil(X[0])
        plt.subplot(1, 2, 2)
        plt.imshow(iii)
        plt.show()
        plt.pause(0.1)
    #plt.ioff()
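# How the three models map onto RGB (added sketch, using only names defined above).
# train() and test() feed channel k of the batch, X[:, k] reshaped to (N, 1, H, W),
# through its own DRRN, then stack the three single-channel outputs back into an
# (N, 3, H, W) image. The round trip below shows that split-and-stack is lossless.
def _split_merge_demo():
    X = torch.rand(2, 3, 32, 32)
    _, c, h, w = X.shape
    channels = [X[:, k].reshape(-1, 1, h, w) for k in range(c)]   # three (N, 1, H, W) batches
    merged = torch.stack([ch[:, 0] for ch in channels], dim=1)    # back to (N, 3, H, W)
    print(torch.equal(X, merged))                                 # True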
def run():
    model1 = torch.load("./model/srcnn_l1.m")
    model2 = torch.load("./model/srcnn_l2.m")
    model3 = torch.load("./model/srcnn_l3.m")
    model1.eval()
    model2.eval()
    model3.eval()
    path = "./dataset/Set5/image_SRF_2"
    file_list = os.listdir(path)
    #plt.ion()
    # block size of the LR patches fed to the models
    lh = opt['pic_h']
    lw = opt['pic_w']
    # the stride for moving the processing window
    stride_h = lh - opt['cat_padding'] * 2
    stride_w = lw - opt['cat_padding'] * 2
    # the centre sub-block of each model output that is actually kept
    y_top = (opt['cat_padding'])
    y_left = (opt['cat_padding'])
    y_down = (lh - opt['cat_padding'])
    y_right = (lw - opt['cat_padding'])
    for img_name in file_list:
        # load the image
        img = Image.open(path + '/' + img_name)
        if len(img.getbands()) == 1:
            img = img.convert("RGB")
        c = 3
        org_img = img
        org_w, org_h = org_img.size
        dst_w, dst_h = org_w * opt['pic_up'], org_h * opt['pic_up']
        # upscale to the target size first; the networks then refine the result
        transform_up = transforms.Compose([
            #transforms.RandomResizedCrop(pad_h, scale=(2.0,2.0)),
            transforms.Resize((dst_h, dst_w)),
            #transforms.CenterCrop( (pad_h, pad_w) ),
        ])
        img = transform_up(img)
        # pad the edges so the sliding window covers the whole image
        pad_h = dst_h + opt['cat_padding'] * 2
        pad_w = dst_w + opt['cat_padding'] * 2
        pad_h_off = ((pad_h - 2 * opt['cat_padding']) % stride_h)
        pad_w_off = ((pad_w - 2 * opt['cat_padding']) % stride_w)
        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off
        # pad_h must == pad_w
        transform_pad = transforms.Compose([
            #transforms.RandomResizedCrop(pad_h, scale=(2.0,2.0)),
            transforms.CenterCrop((pad_h, pad_w)),
        ])
        img = transform_pad(img)
        w, h = img.size
        # srcnn-style preprocessing: convert the padded image to a tensor
        transform_mid = transforms.Compose([
            #transforms.CenterCrop( (sr_h, sr_w) ),
            transforms.ToTensor(),
            #transforms.Normalize(means, stds)
        ])
        transform_last = transforms.Compose([
            transforms.CenterCrop((dst_h, dst_w)),
            transforms.ToTensor(),
            #transforms.Normalize(means, stds), as the Y_ has been verse normalized.
        ])
        Y = torch.zeros((1, 3, h, w))
        img_ = transform_mid(img)
        top = 0
        down = top + lh
        while down <= h:
            left = 0
            right = left + lw
            while right <= w:
                X_ = img_[:, top:down, left:right].detach()
                X_ = X_.reshape(-1, c, lh, lw)
                with torch.no_grad():
                    # each patch is lh x lw, so reshape with lh/lw (not the full image size)
                    Y_1 = model1(X_[:, 0].reshape(-1, 1, lh, lw))
                    Y_2 = model2(X_[:, 1].reshape(-1, 1, lh, lw))
                    Y_3 = model3(X_[:, 2].reshape(-1, 1, lh, lw))
                Y_ = torch.stack([Y_1[:, 0], Y_2[:, 0], Y_3[:, 0]], dim=1)
                # paste only the central region of the patch into the output canvas
                cent_top = (top + opt['cat_padding'])
                cent_left = (left + opt['cat_padding'])
                cent_down = (down - opt['cat_padding'])
                cent_right = (right - opt['cat_padding'])
                Y[0, :, cent_top:cent_down, cent_left:cent_right] = Y_[0, :, y_top:y_down, y_left:y_right]
                # move the window to the right
                left += stride_w
                right = left + lw
            # move the window down
            top += stride_h
            down = top + lh
        Y = verse_normalize(Y, means, stds)
        Y = Y[0]
        Y = tensor_to_pil(Y)
        Y = transform_last(Y)
        Y = tensor_to_pil(Y)
        Y.save("./DRRN.png")
        display_img(Y, org_img)
        plt.pause(1)
        plt.clf()

"""
For training run this command:
    python face_cnn.py train
For testing run this command:
    python face_cnn.py test
"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    if (len(args) == 1) and (args[0] == "train"):
        train()
    elif (len(args) == 1) and (args[0] == "run"):
        run()
    else:
        test(model1, model2, model3)
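A quick worked example of the tiling arithmetic in run(), under the settings above (lh = lw = 512, cat_padding = 5, so the stride is 502): a 256 x 256 input upscaled by pic_up = 2 gives dst_h = dst_w = 512; padding adds 10 to make 522, and since (522 - 10) % 502 = 10, another 492 is added, so the padded canvas is 1014 x 1014. That canvas is covered by a 2 x 2 grid of overlapping 512 x 512 patches whose central 502 x 502 regions are stitched edge to edge, and transform_last finally center-crops back to 512 x 512. The helper below is hypothetical (not part of the script); it just mirrors the pad_h / pad_w computation so the numbers can be checked in isolation.

def padded_size(dst, lh=512, cat_padding=5):
    # mirrors the pad_h / pad_w computation inside run()
    stride = lh - 2 * cat_padding
    pad = dst + 2 * cat_padding
    off = (pad - 2 * cat_padding) % stride
    if off != 0:
        pad += stride - off
    return pad

print(padded_size(512))   # 1014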
Tags: python machine_learning DL