import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random
# Per-channel mean / standard deviation triples. In this file they are only
# referenced by commented-out transforms.Normalize calls.
# NOTE(review): presumably the usual ImageNet statistics -- confirm.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
# Side length (pixels) of the square crops fed to the networks.
pic_height = pic_width = 512
"""
# errors
def psnr(img1, img2):
img1 = Variable(img1)
img2 = Variable(img2)
mse = ( (img1/1.0) - (img2/1.0) ) ** 2
mse = np.mean( np.array(mse) )
if mse < 1.0e-10:
return 100
return 10 * math.log10(255.0**2/mse)
"""
def psnr(img1, img2, max_value=255.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Both inputs are normalized by ``max_value`` before the mean squared
    error is taken, so ``PSNR = 10 * log10(1 / mse)``.

    Args:
        img1: First image (tensor or anything ``torch.as_tensor`` accepts).
        img2: Second image, same shape as ``img1``.
        max_value: Largest possible pixel value. The default of 255 matches
            8-bit pixel data; pass ``1.0`` for tensors scaled to ``[0, 1]``.

    Returns:
        The PSNR as a float, or ``100`` when the normalized MSE is below
        ``1e-10`` (images are treated as identical).
    """
    # detach() lets tensors that are part of an autograd graph be measured;
    # double precision avoids uint8 wrap-around in the subtraction.
    first = torch.as_tensor(img1).detach().cpu().double()
    second = torch.as_tensor(img2).detach().cpu().double()
    mse = torch.mean(((first - second) / max_value) ** 2).item()
    if mse < 1.0e-10:
        return 100
    # The peak is 1 after normalization, so the ratio is 1 / mse (not
    # 1 / sqrt(mse), which would halve the result in dB).
    return 10 * math.log10(1.0 / mse)
def rgb2ycbcr(rgb_image):
    """Convert an RGB image array to YCbCr.

    Args:
        rgb_image: Array whose last axis holds the three RGB channels
            (rows x columns x 3).

    Returns:
        A float64 array of the same shape with Y, Cb, Cr channels. Input
        that is not a 3-D array with exactly three channels is returned
        unchanged.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    pixels = rgb_image.astype(np.float32)
    # One matmul over the channel axis replaces the per-pixel Python loop:
    # for every pixel p this computes transform_matrix @ p + shift_matrix.
    return pixels @ transform_matrix.T + shift_matrix
def ycbcr2rgb(ycbcr_image):
    """Convert a YCbCr image array back to 8-bit RGB.

    Applies the inverse of the matrix/offset pair used by ``rgb2ycbcr``.

    Args:
        ycbcr_image: Array whose last axis holds the Y, Cb, Cr channels
            (rows x columns x 3).

    Returns:
        A uint8 array of the same shape. Values are rounded to the nearest
        integer and clipped to ``[0, 255]``. Input that is not a 3-D array
        with exactly three channels is returned unchanged.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    pixels = ycbcr_image.astype(np.float32)
    # inv @ (p - shift) for every pixel at once; equivalent to the original
    # inv @ p - inv @ shift but without the per-pixel Python loop.
    rgb_image = (pixels - shift_matrix) @ transform_matrix_inv.T
    # Round and clip before the cast: a bare astype(uint8) truncates
    # (254.9999 -> 254) and wraps values outside the 0..255 range.
    return np.clip(np.rint(rgb_image), 0, 255).astype(np.uint8)
class dataloader:
    """Iterator yielding batches of (blurred, original) image tensors.

    Files are listed from ``path`` once at construction. The listing is
    split at ``test_offset``: the leading part is served in training mode,
    the trailing part in test mode. Reading wraps around inside the selected
    range, so the loader can be iterated repeatedly.

    Args:
        path: Directory containing the image files.
        batchs: Number of batches per iteration pass; a value <= 0 means
            the iterator never raises ``StopIteration``.
        batch_size: Number of image pairs stacked into each batch.
        test: Serve the trailing (test) part of the listing when true.
        test_offset: Fraction of the listing at which the split happens.
        order: Sort the listing when true, otherwise shuffle it.
    """

    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)
        if self.test:
            first = self.file_number * self.test_offset
            last = self.file_number - 1
        else:
            first = 0
            last = self.file_number * self.test_offset - 1
        self.file_start = int(np.floor(first))
        self.file_stop = int(np.floor(last))
        self.file_idx = self.file_start
        # Half-resolution resize; not used by _next in its current form.
        self.downsample = transforms.Compose([
            transforms.Resize((int(pic_height / 2), int(pic_width / 2)), Image.BILINEAR),
        ])
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
            transforms.ToTensor(),
        ])

    def get_len(self):
        """Return the distance between the stop and start file indices."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        """Return one ``(X, Y)`` batch, each stacked along a new dim 0.

        Raises:
            StopIteration: after ``batchs`` batches (when ``batchs`` > 0);
                the batch counter is reset so a new pass can start.
        """
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1
        pairs = [self._next() for _ in range(self.batch_size)]
        X = torch.stack([blurred for blurred, _ in pairs], 0)
        Y = torch.stack([original for _, original in pairs], 0)
        return X, Y

    def _next(self):
        """Load the next file and return its (blurred, original) tensors."""
        # NOTE(review): wrapping on ">=" means the file at index file_stop
        # itself is never served -- confirm whether that is intended.
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start
        file_path = os.path.join(self.path, self.file_list[self.file_idx])
        self.file_idx += 1
        # Close the file handle deterministically. convert() loads the pixel
        # data and returns an independent image, and forcing RGB keeps
        # grayscale, palette, RGBA and CMYK files at three channels.
        with Image.open(file_path) as source:
            Y = source.convert("RGB")
        X = Y.filter(ImageFilter.BLUR)
        # NOTE(review): the crop transform is applied to X and Y in separate
        # calls; confirm both calls always select the same region.
        X = self.transform(X)
        Y = self.transform(Y)
        return X, Y
class VDSR(nn.Module):
    """Stack of six 3x3 convolutions mapping one channel to one channel.

    Every convolution uses stride 1 and padding 1, so the spatial size of
    the input is preserved. A ReLU follows each convolution except the last.
    """

    def __init__(self):
        super().__init__()
        # There are 20 layers in the original paper; this version uses six.
        widths = [1, 64, 64, 64, 64, 64, 1]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(fan_in, fan_out, kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU())
        stages.pop()  # the final convolution has no activation after it
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.cnn(x)

    def init_weights(self):
        """Print a marker line; no parameters are modified here."""
        print(" ****************************")
# Shared converter used to turn image tensors back into PIL images for display.
tensor_to_pil = ToPILImage()
# Three independent networks; train/test/run feed channel 0, 1 and 2 of the
# input to model1, model2 and model3 respectively.
model1, model2, model3 = VDSR(), VDSR(), VDSR()
# Disabled scratch code (kept as a bare string literal, never executed).
'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
plt.imshow(X)
plt.pause(0.1)
plt.ioff()
iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()
'''
# Both loaders list their directory as soon as this module is loaded.
dataloader_train = dataloader(
    path='/home/nicholas/Documents/dataset/imageNet',
    batchs=200,
    batch_size=1,
)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
dataloader_test = dataloader(
    path='./dataset/Set14/image_SRF_3',
    batchs=1,
    batch_size=1,
)
# 1 -> train() starts from the freshly built models; any other value makes
# train() reload the saved models from ./model first.
new_training = 1
def train():
    """Train the three per-channel networks on blurred/original pairs.

    For every batch from ``dataloader_train`` each network receives one
    channel of the blurred input ``X`` and is fitted, with MSE loss and
    Adam, to the same channel of the residual ``Y - X``. ``test()`` is
    called after every 10th batch, and the three models are saved under
    ``./model`` when the loader is exhausted.

    When ``new_training`` is not 1 the models are first reloaded from
    ``./model`` and rebound to the module-level names.
    """
    global model1
    global model2
    global model3
    if new_training == 1:
        print(" new training ...")
        model1.init_weights()
        model2.init_weights()
        model3.init_weights()
    else:
        print(" continue training ...")
        # Drop the freshly built models before loading the saved ones.
        del model1
        del model2
        del model3
        # NOTE(review): torch.load unpickles whole model objects here; only
        # load checkpoint files from a trusted source.
        model1 = torch.load("./model/srcnn_l1.m")
        model2 = torch.load("./model/srcnn_l2.m")
        model3 = torch.load("./model/srcnn_l3.m")
    loss = nn.MSELoss()
    networks = (model1, model2, model3)
    optimizers = [
        torch.optim.Adam(net.parameters(), lr=0.00001, betas=(0.9, 0.999))
        for net in networks
    ]
    plt.ion()
    for i, (X, Y) in enumerate(dataloader_train):
        # The networks learn the residual between target and blurred input.
        R = Y - X
        del Y
        _, _, h, w = X.shape
        costs = []
        for channel, (net, optimizer) in enumerate(zip(networks, optimizers)):
            # test() switches the models to eval mode, so re-enable training
            # mode on every batch.
            net.train()
            prediction = net(X[:, channel].reshape(-1, 1, h, w))
            optimizer.zero_grad()
            cost = loss(prediction, R[:, channel].reshape(-1, 1, h, w))
            cost.backward()
            optimizer.step()
            costs.append(cost.item())
        del X
        del R
        print("batch:{}, loss is {:.4f} {:.4f} {:.4f} ".format(i, *costs))
        if (i + 1) % 10 == 0:
            test()
    plt.ioff()
    # Make sure the output directory exists so a finished run is not lost
    # to a missing-folder error at save time.
    os.makedirs("./model", exist_ok=True)
    torch.save(model1, "./model/srcnn_l1.m")
    torch.save(model2, "./model/srcnn_l2.m")
    torch.save(model3, "./model/srcnn_l3.m")
def test():
    """Evaluate the three per-channel networks on ``dataloader_test``.

    For each batch the predicted residuals are stacked back into a
    three-channel tensor and added to the blurred input ``X``. The PSNR of
    the result and of the untouched input against the target ``Y`` is
    printed, and both images are shown side by side.
    """
    networks = (model1, model2, model3)
    # Inference only: without no_grad the outputs carry autograd state,
    # which wastes memory and breaks conversion to NumPy / PIL.
    with torch.no_grad():
        for X, Y in dataloader_test:
            print(X.shape)
            for net in networks:
                net.eval()
            _, c, h, w = X.shape
            outputs = [
                net(X[:, channel].reshape(-1, 1, h, w))
                for channel, net in enumerate(networks)
            ]
            print(c, h, w)
            R = torch.stack([out[:, 0] for out in outputs], dim=1)
            del outputs
            # Reconstructed image = predicted residual + blurred input.
            Y_ = R + X
            del R
            # ToTensor yields values in [0, 1] while psnr normalizes by 255,
            # so rescale to the 0-255 range before measuring.
            print(" Psnr = ", psnr(Y_ * 255.0, Y * 255.0),
                  " psnr2 = ", psnr(X * 255.0, Y * 255.0))
            del Y
            # Clamp before display: ToPILImage scales floats by 255 and
            # casts to bytes, so out-of-range values would wrap around.
            restored = tensor_to_pil(Y_[0].clamp(0, 1))
            del Y_
            plt.subplot(1, 2, 1)
            plt.imshow(restored)
            blurred = tensor_to_pil(X[0])
            del X
            plt.subplot(1, 2, 2)
            plt.imshow(blurred)
            plt.show()
            plt.pause(0.1)
def run():
    """Apply the saved networks to every image in ``./images`` and show it.

    The three models are reloaded from ``./model`` and rebound to the
    module-level names. Each image is cropped to ``pic_height``, split into
    channels, passed through the matching network, and the predicted
    residual is added back to the input. Result and input are displayed
    side by side for about one second per image.
    """
    # Without this declaration the assignments below make the names local
    # and the `del` statements raise UnboundLocalError.
    global model1
    global model2
    global model3
    del model1
    del model2
    del model3
    # NOTE(review): torch.load unpickles whole model objects here; only load
    # checkpoint files from a trusted source.
    model1 = torch.load("./model/srcnn_l1.m")
    model2 = torch.load("./model/srcnn_l2.m")
    model3 = torch.load("./model/srcnn_l3.m")
    networks = (model1, model2, model3)
    for net in networks:
        net.eval()
    path = "./images"
    file_list = os.listdir(path)
    transform = transforms.Compose([
        transforms.RandomResizedCrop(pic_height, scale=(1, 1)),
        transforms.ToTensor(),
    ])
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for img_name in file_list:
            # Force three channels so every file matches the three networks;
            # convert() loads the data, so the file can be closed right away.
            with Image.open(path + '/' + img_name) as img:
                rgb = img.convert("RGB")
            X = transform(rgb).unsqueeze(0)
            del rgb
            _, _, h, w = X.shape
            outputs = [
                net(X[:, channel].reshape(-1, 1, h, w))
                for channel, net in enumerate(networks)
            ]
            residual = torch.stack([out[:, 0] for out in outputs], dim=1)
            del outputs
            # Reconstructed image = predicted residual + input.
            Y = residual + X
            del residual
            # Clamp before display: ToPILImage scales floats by 255 and
            # casts to bytes, so out-of-range values would wrap around.
            restored = tensor_to_pil(Y[0].clamp(0, 1))
            del Y
            plt.subplot(1, 2, 1)
            plt.imshow(restored)
            original = tensor_to_pil(X[0])
            del X
            plt.subplot(1, 2, 2)
            plt.imshow(original)
            plt.show()
            plt.pause(1)
            plt.clf()
"""
For training run this command:
python face_cnn.py train
For testing fun this command:
python face_cnn.py test
"""
def select_mode(args):
    """Return "train", "run" or "test" for the given command-line arguments.

    Exactly one argument equal to "train" or "run" selects that mode; any
    other input, including no arguments at all, selects "test".
    """
    # `and` short-circuits, so args[0] is never read from an empty list
    # (the bitwise `&` form evaluated both sides and raised IndexError).
    if len(args) == 1 and args[0] in ("train", "run"):
        return args[0]
    return "test"


if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    mode = select_mode(args)
    if mode == "train":
        train()
    elif mode == "run":
        run()
    else:
        test()