[DL] SRCNN: Super-Resolution with CNN
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random
# Per-channel mean / std statistics intended for transforms.Normalize; the
# Normalize steps in the dataloader transforms are commented out, so these
# values are currently not applied anywhere.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

# Target image size in pixels: pic_height is the RandomResizedCrop size,
# and the downsample transform halves both dimensions.
pic_width = 512
pic_height = 512
def psnr(img1, img2, data_range=255.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Both inputs are divided by ``data_range`` so that the peak signal is 1.0,
    then PSNR = 10 * log10(1 / MSE). The default (255) suits 8-bit pixel values;
    pass ``data_range=1.0`` for images already scaled to [0, 1].

    Args:
        img1, img2: torch tensors (gradient-tracking tensors are fine) or
            array-likes of the same shape.
        data_range: maximum possible pixel value of the inputs.

    Returns:
        The PSNR as a float, or 100 when the images are (nearly) identical.
    """
    def _as_array(img):
        # Detach first: np.asarray() refuses tensors that require grad.
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        # float64 also avoids uint8 wrap-around when the inputs are subtracted.
        return np.asarray(img, dtype=np.float64) / data_range

    mse = float(np.mean((_as_array(img1) - _as_array(img2)) ** 2))
    if mse < 1.0e-10:
        return 100
    # PSNR = 10*log10(peak**2 / MSE) with peak == 1 after normalisation.
    return 10 * math.log10(1.0 / mse)
# ITU-R BT.601 RGB -> YCbCr coefficients for 8-bit input (Y offset 16,
# chroma offset 128). For 0-255 RGB this yields Y in [16, 235] and Cb/Cr in
# [16, 240]. Shared by rgb2ycbcr and ycbcr2rgb so the two are exact inverses.
_YCBCR_MATRIX = np.array([[0.257, 0.504, 0.098],
                          [-0.148, -0.291, 0.439],
                          [0.439, -0.368, -0.071]])
_YCBCR_SHIFT = np.array([16, 128, 128])


def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB array with 0-255 values to BT.601 YCbCr.

    Returns a float64 array of the same shape. Arrays that are not 3-D with
    exactly 3 channels are returned unchanged.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    rgb = rgb_image.astype(np.float32)
    # Vectorised form of "ycbcr[i, j] = M @ rgb[i, j] + shift" for every pixel.
    return rgb @ _YCBCR_MATRIX.T + _YCBCR_SHIFT


def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) BT.601 YCbCr array back to uint8 RGB.

    Inverse of rgb2ycbcr. Arrays that are not 3-D with exactly 3 channels are
    returned unchanged.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    ycbcr = ycbcr_image.astype(np.float32)
    inverse = np.linalg.inv(_YCBCR_MATRIX)
    rgb = (ycbcr - _YCBCR_SHIFT) @ inverse.T
    # Round and saturate before casting: a bare astype(np.uint8) truncates
    # (254.9999 -> 254) and wraps out-of-range values instead of clipping them.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
class dataloader:
    """Iterator over (X, Y) image batches read from one directory.

    Each file in ``path`` is opened as a target image Y and converted to RGB.
    The input X is Y blurred with ``ImageFilter.BLUR``. X and Y are then cropped
    and resized with the SAME ``RandomResizedCrop(pic_height, scale=(1, 1))``
    parameters and converted to tensors. Each ``__next__`` returns
    ``(X, Y)``, each stacked to shape (batch_size, 3, pic_height, pic_height).

    Files are split by position in ``file_list``. The first ``test_offset``
    fraction is the training split and the rest is the test split
    (``test=True``). Within a split, files are reused cyclically.

    Args:
        path: directory to read; assumes every entry is an image file.
        batchs: batches per pass. After that many batches StopIteration is
            raised and the counter resets. ``batchs <= 0`` never stops.
        batch_size: images per batch.
        test: read the test split instead of the training split.
        test_offset: fraction of files assigned to the training split.
        order: sort the file names instead of shuffling them.
    """

    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)
        # [file_start, file_stop) is a half-open index range into file_list
        # (_next wraps when file_idx >= file_stop), so file_stop is exclusive.
        split = int(np.floor(self.file_number * self.test_offset))
        if self.test:
            self.file_start = split
            self.file_stop = self.file_number
        else:
            self.file_start = 0
            self.file_stop = split
        self.file_idx = self.file_start
        # Kept for callers/experiments; _next currently uses blurring instead.
        self.downsample = transforms.Compose([
            transforms.Resize((int(pic_height / 2), int(pic_width / 2)), Image.BILINEAR),
        ])
        # The crop is kept separately so _next can draw its random parameters once
        # and apply them to both X and Y.
        self.crop = transforms.RandomResizedCrop(pic_height, scale=(1, 1))
        self.transform = transforms.Compose([
            self.crop,
            transforms.ToTensor(),
        ])

    def get_len(self):
        """Return the number of files in the selected split."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        # A positive `batchs` bounds one pass; the counter is reset so the
        # loader can be iterated again.
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1
        pairs = [self._next() for _ in range(self.batch_size)]
        X = torch.stack([x for x, _ in pairs], 0)
        Y = torch.stack([y for _, y in pairs], 0)
        return X, Y

    def _next(self):
        """Load the next file of the split and return a pixel-aligned (X, Y) pair."""
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start
        file_path = os.path.join(self.path, self.file_list[self.file_idx])
        self.file_idx += 1
        with Image.open(file_path) as src:
            # Convert every non-RGB mode (L, P, LA, RGBA, CMYK, ...) so each sample
            # has exactly 3 channels and the batch can be stacked; convert() also
            # returns an independent copy, so the file can be closed here.
            Y = src.convert("RGB")
        X = Y.filter(ImageFilter.BLUR)
        # Draw the random crop once and apply it to both images, so X stays
        # aligned with its target Y.
        top, left, height, width = self.crop.get_params(Y, self.crop.scale, self.crop.ratio)
        X = transforms.functional.resized_crop(
            X, top, left, height, width, self.crop.size, self.crop.interpolation)
        Y = transforms.functional.resized_crop(
            Y, top, left, height, width, self.crop.size, self.crop.interpolation)
        return transforms.functional.to_tensor(X), transforms.functional.to_tensor(Y)
class SRCNN(nn.Module):
    """Three-layer SRCNN for single-channel images.

    patch extraction (9x9 conv, 1 -> 64, ReLU) -> non-linear mapping
    (1x1 conv, 64 -> 32, ReLU) -> reconstruction (5x5 conv, 32 -> 1, no
    activation). The padding keeps the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        self.patch_extraction = nn.Conv2d(
            1, 64, kernel_size=9, stride=1, padding=4, padding_mode='replicate')
        self.non_linear = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)
        self.reconstruction = nn.Conv2d(
            32, 1, kernel_size=5, stride=1, padding=2, padding_mode='replicate')

    def init_weights(self):
        """Re-initialise all convs: weights ~ N(0, 0.001^2), biases set to zero."""
        print(" ****************************")
        for layer in (self.patch_extraction, self.non_linear, self.reconstruction):
            nn.init.normal_(layer.weight, mean=0.0, std=0.001)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        hidden = F.relu(self.patch_extraction(x))
        hidden = F.relu(self.non_linear(hidden))
        # Linear output layer (a sigmoid output was tried earlier and dropped).
        return self.reconstruction(hidden)
# Tensor -> PIL image converter used for plotting in test() and run().
tensor_to_pil = ToPILImage()

# Three independent single-channel SRCNNs: train() feeds channel 0/1/2 of the
# image tensors to model1/model2/model3 respectively.
model1 = SRCNN()
model2 = SRCNN()
model3 = SRCNN()
loss = nn.MSELoss()


def _per_layer_adam(model):
    """Build an Adam optimizer with per-layer learning rates for one SRCNN.

    The patch-extraction and non-linear-mapping convs use lr=1e-4, and the
    reconstruction conv uses the smaller lr=1e-5. Other Adam settings are defaults.
    """
    return torch.optim.Adam([
        {"params": model.patch_extraction.parameters(), "lr": 0.0001},
        {"params": model.non_linear.parameters(), "lr": 0.0001},
        {"params": model.reconstruction.parameters(), "lr": 0.00001},
    ])


# Alternatives previously kept as disabled string literals (never executed):
# SGD with the same three per-layer learning rates, and a single-group
# Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999)).
optimizer1, optimizer2, optimizer3 = (
    _per_layer_adam(model) for model in (model1, model2, model3)
)

# 200 batches of 2 images per pass, from a hard-coded local directory.
dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', 200, 2)
# One batch of 2 images per pass. NOTE: test=False (the default), so this
# reads the first 90% of the folder's files, not a held-out split.
dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 2)
# 1 -> train() re-initialises the weights; anything else -> train() resumes
# from the files under ./model/.
new_training = 0
def train():
    """Train the three per-channel SRCNN models on ``dataloader_train``.

    If ``new_training == 1``, the module-level models are re-initialised with
    ``init_weights``. Otherwise their weights are restored from
    ./model/srcnn_l{1,2,3}.m. ``test`` runs every 10 batches. At the end the
    models are saved back to the same files.
    """
    # Channel k of X/Y is handled by the k-th (model, optimizer) pair.
    pairs = ((model1, optimizer1), (model2, optimizer2), (model3, optimizer3))
    paths = ("./model/srcnn_l1.m", "./model/srcnn_l2.m", "./model/srcnn_l3.m")
    if new_training == 1:
        print(" new training ...")
        for model, _ in pairs:
            model.init_weights()
    else:
        print(" continue training ...")
        # Copy the saved weights INTO the module-level models instead of
        # rebinding the names. Rebinding made model1..3 function locals
        # (UnboundLocalError on the new-training path) and left optimizer1..3
        # updating the old, discarded parameters.
        # NOTE(review): recent torch releases default torch.load to
        # weights_only=True, which rejects whole pickled modules; pass
        # weights_only=False there if loading fails.
        for (model, _), path in zip(pairs, paths):
            model.load_state_dict(torch.load(path).state_dict())
    plt.ion()
    for i, (X, Y) in enumerate(dataloader_train):
        for model, _ in pairs:
            model.train()
        _, _, h, w = X.shape
        costs = []
        for channel, (model, optimizer) in enumerate(pairs):
            prediction = model(X[:, channel].reshape(-1, 1, h, w))
            optimizer.zero_grad()
            cost = loss(prediction, Y[:, channel].reshape(-1, 1, h, w))
            cost.backward()
            optimizer.step()
            costs.append(cost.item())
        print("batch:{}, loss is {:.4f} {:.4f} {:.4f} ".format(i, *costs))
        # Evaluate after every 10th batch (batches are numbered from 0).
        if (i + 1) % 10 == 0:
            test(model1, model2, model3)
    plt.ioff()
    for (model, _), path in zip(pairs, paths):
        torch.save(model, path)
def test(model1, model2, model3):
    """Evaluate the three per-channel models on ``dataloader_test`` and plot results.

    For each batch, this prints the PSNR of the model output vs. the target and
    of the blurred input vs. the target. It then plots the first output next
    to the first input.
    """
    models = (model1, model2, model3)
    for model in models:
        model.eval()
    # Inference only: skip building an autograd graph, which also lets psnr()
    # convert the outputs with np.array() (it cannot for grad-tracking tensors).
    with torch.no_grad():
        for X, Y in dataloader_test:
            _, c, h, w = X.shape
            print(c, h, w)
            Y_ = torch.stack(
                [model(X[:, ch].reshape(-1, 1, h, w))[:, 0] for ch, model in enumerate(models)],
                dim=1)
            # The tensors come from ToTensor (0-1 range) while psnr() divides by
            # 255, so rescale to 0-255 before measuring.
            print(" Psnr = ", psnr(Y_ * 255.0, Y * 255.0),
                  " psnr2 = ", psnr(X * 255.0, Y * 255.0))
            # The reconstruction layer is linear, so outputs can leave [0, 1];
            # clamp before converting to an 8-bit PIL image for display.
            plt.subplot(1, 2, 1)
            plt.imshow(tensor_to_pil(Y_[0].clamp(0.0, 1.0)))
            plt.subplot(1, 2, 2)
            plt.imshow(tensor_to_pil(X[0]))
            plt.show()
            plt.pause(0.1)
def run():
    """Super-resolve every image in ./images with the saved models and plot each result.

    The three models are loaded from ./model/srcnn_l{1,2,3}.m. Each channel of
    the image is processed by its own model, and output and input are shown
    side by side.
    """
    model_paths = ("./model/srcnn_l1.m", "./model/srcnn_l2.m", "./model/srcnn_l3.m")
    # NOTE(review): recent torch releases default torch.load to weights_only=True,
    # which rejects whole pickled modules; pass weights_only=False if needed.
    models = [torch.load(p) for p in model_paths]
    for model in models:
        model.eval()
    path = "./images"
    file_list = os.listdir(path)
    transform = transforms.Compose([
        transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
        transforms.ToTensor(),
    ])
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for img_name in file_list:
            with Image.open(os.path.join(path, img_name)) as src:
                # Any non-RGB mode (L, LA, P, RGBA, ...) becomes 3 channels, one
                # per model; convert() returns a copy, so the file can be closed.
                img = src.convert("RGB")
            X = transform(img).unsqueeze(0)
            _, _, h, w = X.shape
            Y_ = torch.stack(
                [model(X[:, ch].reshape(-1, 1, h, w))[:, 0] for ch, model in enumerate(models)],
                dim=1)
            # The reconstruction layer is linear; clamp to [0, 1] before the
            # 8-bit PIL conversion so out-of-range values do not wrap.
            plt.subplot(1, 2, 1)
            plt.imshow(tensor_to_pil(Y_[0].clamp(0.0, 1.0)))
            plt.subplot(1, 2, 2)
            plt.imshow(tensor_to_pil(X[0]))
            plt.show()
            plt.pause(1)
"""
For training, run this command:
python face_cnn.py train
For testing, run this command:
python face_cnn.py test
To super-resolve every image in ./images, run this command:
python face_cnn.py run
"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    # Read args[0] only when exactly one argument was given, so running with no
    # arguments falls through to the test branch instead of raising IndexError.
    command = args[0] if len(args) == 1 else None
    if command == "train":
        train()
    elif command == "run":
        run()
    else:
        # test() requires the three models; evaluate the saved ones.
        test(torch.load("./model/srcnn_l1.m"),
             torch.load("./model/srcnn_l2.m"),
             torch.load("./model/srcnn_l3.m"))
标签: DL
日历
最新微语
- 有的时候，会站在分叉路口，不知道向左还是向右。
2023-12-26 15:34
- 繁花乱开,鸟雀逐风。心自宁静,纷扰不闻。
2023-03-14 09:56
- 对于不可控的事,我们保持乐观,对于可控的事情,我们保持谨慎。
2023-02-09 11:03
- 小时候,
暑假意味着无忧无虑地玩很长一段时间,
节假意味着好吃好喝还有很多长期不见的小朋友来玩...
长大后,
这是女儿第一个暑假,
一个半月...
2022-07-11 08:54
- Watching the autumn leaves falling as you grow older together
2018-10-25 09:45
分类
最新评论
- Goonog
i get it now :) - 萧
@Fluzak:The web host... - Fluzak
Nice blog here! Also... - Albertarive
In my opinion you co... - ChesterHep
What does it plan? - ChesterHep
No, opposite. - mojoheadz
Everything is OK!... - Josephmaigh
I just want to say t... - ChesterHep
What good topic - AnthonyBub
Certainly, never it ...
发表评论: