苗火 Nicholas
[DL]ESPCN
2022-6-9 萧
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random

means = [0.485, 0.456, 0.406]
stds = [ 0.229, 0.224, 0.225]

pic_height = 512
pic_width = 512
upscale_factor = 2

opt = {
    "pic_c": 32,          # channel count of the middle convolution layers
    "pic_h": pic_height,  # height of a high-resolution crop; the low-resolution input is pic_h // pic_up
    "pic_w": pic_width,   # width of a high-resolution crop
    "pic_up": 2,          # upscale factor
    "heads": 8,           # entries from here down to "eof" are not referenced in this script
    "vec_dim": 512,
    "N": 6,
    "x_vocab_len": 0,
    "y_vocab_len": 0,
    "sentence_len": 80,
    "batchs": 1000,
    "batch_size": 10,
    "pad": 0,
    "sof": 1,
    "eof": 2,
    "cat_padding": 5,     # border (in low-resolution pixels) trimmed from each block when stitching in run()
}

def verse_normalize(img, means, stds):
    """Undo transforms.Normalize(means, stds) on a batch of 3-channel images."""
    means = torch.Tensor(means)
    stds = torch.Tensor(stds)
    for i in range(3):
        img[:, i] = img[:, i] * stds[i] + means[i]
    return img


"""
# errors
def psnr(img1, img2):
img1 = Variable(img1)
img2 = Variable(img2)
mse = ( (img1/1.0) - (img2/1.0) ) ** 2
mse = np.mean( np.array(mse) )
if mse < 1.0e-10:
return 100
return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2):
    """PSNR in dB for image tensors already scaled to [0, 1] (as ToTensor produces)."""
    img1 = img1.detach()
    img2 = img2.detach()
    mse = np.mean(np.array((img1 - img2) ** 2))
    if mse < 1.0e-10:
        return 100
    # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 1.0 for [0, 1] tensors
    return 10 * math.log10(1.0 / mse)
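# A minimal sanity check for psnr() -- kept commented out like the other example
# blocks in this file; the tensors below are illustrative only.
"""
a = torch.rand(1, 3, 64, 64)
b = (a + 0.01 * torch.randn_like(a)).clamp(0, 1)
print(psnr(a, a))   # identical inputs hit the 100 dB cap
print(psnr(a, b))   # a lightly perturbed copy lands around 40 dB
"""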


def rgb2ycbcr(rgb_image):
    """Convert an RGB image (H x W x 3, uint8 range) into YCbCr (ITU-R BT.601)."""
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
        #raise ValueError("input image is not a rgb image")
    rgb_image = rgb_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    ycbcr_image = np.zeros(shape=rgb_image.shape)
    h, w, _ = rgb_image.shape
    for i in range(h):
        for j in range(w):
            ycbcr_image[i, j, :] = np.dot(transform_matrix, rgb_image[i, j, :]) + shift_matrix
    return ycbcr_image

def ycbcr2rgb(ycbcr_image):
    """Convert a YCbCr image (H x W x 3) back into RGB (ITU-R BT.601)."""
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
        #raise ValueError("input image is not a ycbcr image")
    ycbcr_image = ycbcr_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    rgb_image = np.zeros(shape=ycbcr_image.shape)
    h, w, _ = ycbcr_image.shape
    for i in range(h):
        for j in range(w):
            rgb_image[i, j, :] = np.dot(transform_matrix_inv, ycbcr_image[i, j, :]) - np.dot(transform_matrix_inv, shift_matrix)
    return rgb_image.astype(np.uint8)
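# The per-pixel loops above are easy to read but slow on large images; the same
# conversion can be done with one matrix product over the whole array. A sketch
# (the function name is mine, not part of the original script):
def rgb2ycbcr_fast(rgb_image):
    rgb_image = rgb_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    # (H, W, 3) @ (3, 3)^T applies the matrix to every pixel at once
    return rgb_image @ transform_matrix.T + shift_matrix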

class dataloader:
    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # the first test_offset of the files form the training split, the rest the test split
        self.file_start = 0
        self.file_stop = self.file_number * self.test_offset - 1
        if self.test:
            self.file_start = self.file_number * self.test_offset
            self.file_stop = self.file_number - 1

        self.file_start = int(np.floor(self.file_start))
        self.file_stop = int(np.floor(self.file_stop))

        self.file_idx = self.file_start

        # low-resolution input: center-crop to pic_height, then bicubic-downscale by 2
        self.downsample = transforms.Compose([
            transforms.CenterCrop(pic_height),
            transforms.Resize(pic_height // 2, interpolation=Image.BICUBIC),
            #transforms.Resize( (int(pic_height/2),int(pic_width/2)), Image.BILINEAR),
            #transforms.RandomResizedCrop( int(pic_height/2), scale=(1,1) ),
            #transforms.RandomResizedCrop( pic_height, scale=(0.08,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor()
            #transforms.Normalize(means, stds)
        ])

        # high-resolution target: center-crop only
        self.transform = transforms.Compose([
            transforms.CenterCrop(pic_height),
            #transforms.Resize( (pic_height,pic_width), Image.BILINEAR),
            #transforms.RandomResizedCrop( pic_height, scale=(1,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor()
            #transforms.Normalize(means, stds)
        ])

    def get_len(self):
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        if (self.batchs_cnt >= self.batchs) and (self.batchs > 0):
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        X = []
        Y = []
        for i in range(self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)

        X = torch.stack(X, 0)
        Y = torch.stack(Y, 0)
        return X, Y

    def _next(self):
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start

        file_path = self.path + '/' + self.file_list[self.file_idx]
        self.file_idx += 1
        Y = Image.open(file_path)
        #print( "Y:", file_path, " (", Y.size )

        if len(Y.getbands()) == 1:
            Y = Y.convert("RGB")

        Y_ = self.transform(Y)
        X_ = self.downsample(Y)
        del Y

        return X_, Y_
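# Iterating the loader yields (X, Y) batches -- a sketch with the settings above,
# kept commented out; the path is a placeholder and the folder is assumed to hold
# RGB images of at least 512 x 512 pixels.
"""
loader = dataloader('./some/image/folder', batchs=2, batch_size=4)
for X, Y in loader:
    print(X.shape, Y.shape)   # torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 512, 512])
"""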

class ESPCN(nn.Module):
    def __init__(self, upscale_factor):
        super(ESPCN, self).__init__()
        """
        The original ESPCN paper uses a 3-layer network (5x5, 3x3, 3x3 convolutions);
        this variant adds one extra 3x3 layer. Each model instance handles a single
        channel: the last layer produces upscale_factor ** 2 feature maps that
        PixelShuffle rearranges into the upscaled image.
        """
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 1 * (upscale_factor ** 2), kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = torch.tanh(self.conv1(x))
        x = torch.tanh(self.conv2(x))
        x = torch.tanh(self.conv3(x))
        x = torch.sigmoid(self.pixel_shuffle(self.conv4(x)))
        return x

    def init_weights(self):
        # placeholder: new models keep PyTorch's default initialization
        print(" ****************************")

tensor_to_pil = ToPILImage()

# three single-channel ESPCN models, one per RGB channel
model1 = ESPCN(upscale_factor)
model2 = ESPCN(upscale_factor)
model3 = ESPCN(upscale_factor)

'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
plt.imshow(X)
plt.pause(0.1)
plt.ioff()

iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()


'''
dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', 2000, 4)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 4)
new_training = 0

def display_img(img1, img2):
    plt.subplot(1, 2, 1)
    plt.imshow(img1)

    plt.subplot(1, 2, 2)
    plt.imshow(img2)

    plt.show()


def train():
    global model1
    global model2
    global model3

    if new_training == 1:
        print(" new training ...")
        model1.init_weights()
        model2.init_weights()
        model3.init_weights()
    else:
        print(" continue training ...")
        model1 = torch.load("./model/srcnn_l1.m")
        model2 = torch.load("./model/srcnn_l2.m")
        model3 = torch.load("./model/srcnn_l3.m")

    loss = nn.MSELoss()

    optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.00001, betas=(0.9, 0.999))
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.00001, betas=(0.9, 0.999))
    optimizer3 = torch.optim.Adam(model3.parameters(), lr=0.00001, betas=(0.9, 0.999))

    running_loss = 0
    plt.ion()

    i = 0

    for X, Y in dataloader_train:

        '''
        Train: each model learns one RGB channel.
        '''
        model1.train()
        model2.train()
        model3.train()

        _, lc, lh, lw = X.shape
        _, hc, hh, hw = Y.shape

        Y_1 = model1(X[:, 0].reshape(-1, 1, lh, lw))
        optimizer1.zero_grad()
        cost1 = loss(Y_1, Y[:, 0].reshape(-1, 1, hh, hw))
        cost1.backward()
        optimizer1.step()
        del Y_1

        Y_2 = model2(X[:, 1].reshape(-1, 1, lh, lw))
        optimizer2.zero_grad()
        cost2 = loss(Y_2, Y[:, 1].reshape(-1, 1, hh, hw))
        cost2.backward()
        optimizer2.step()
        del Y_2

        Y_3 = model3(X[:, 2].reshape(-1, 1, lh, lw))
        optimizer3.zero_grad()
        cost3 = loss(Y_3, Y[:, 2].reshape(-1, 1, hh, hw))
        cost3.backward()
        optimizer3.step()
        del Y_3

        del X

        #running_loss += cost.item()
        print("batch:{}, loss is {:.4f} {:.4f} {:.4f}".format(i, cost1.item(), cost2.item(), cost3.item()))

        i += 1
        if i % 10 == 0:
            test()

    plt.ioff()
    torch.save(model1, "./model/srcnn_l1.m")
    torch.save(model2, "./model/srcnn_l2.m")
    torch.save(model3, "./model/srcnn_l3.m")

def test():
    global model1
    global model2
    global model3

    running_loss = 0
    #plt.ion()

    for X, Y in dataloader_test:

        print(X.shape)
        model1.eval()
        model2.eval()
        model3.eval()

        _, lc, lh, lw = X.shape
        _, hc, hh, hw = Y.shape

        # no gradients are needed at test time
        with torch.no_grad():
            Y_1 = model1(X[:, 0].reshape(-1, 1, lh, lw))
            Y_2 = model2(X[:, 1].reshape(-1, 1, lh, lw))
            Y_3 = model3(X[:, 2].reshape(-1, 1, lh, lw))

        # reassemble the three single-channel outputs into an RGB batch
        Y_ = torch.stack([Y_1[:, 0], Y_2[:, 0], Y_3[:, 0]], dim=1)
        del Y_1
        del Y_2
        del Y_3

        print(" Psnr = ", psnr(Y_, Y))
        del Y

        '''
        Display images
        '''
        iii = tensor_to_pil(Y_[0])
        del Y_
        plt.subplot(1, 2, 1)
        plt.imshow(iii)
        del iii

        iii = tensor_to_pil(X[0])
        del X
        plt.subplot(1, 2, 2)
        plt.imshow(iii)
        del iii

        plt.show()
        plt.pause(0.1)

    #plt.ioff()



def run():
    model1 = torch.load("./model/srcnn_l1.m")
    model2 = torch.load("./model/srcnn_l2.m")
    model3 = torch.load("./model/srcnn_l3.m")

    model1.eval()
    model2.eval()
    model3.eval()
    path = "./dataset/Set5/image_SRF_2"

    file_list = os.listdir(path)
    #plt.ion()

    # size of the low-resolution block fed to each model
    lh = opt['pic_h'] // opt['pic_up']
    lw = opt['pic_w'] // opt['pic_up']

    # stride used to slide the block over the padded input
    stride_h = lh - opt['cat_padding'] * 2
    stride_w = lw - opt['cat_padding'] * 2

    # the central sub-block of each model output that is kept when stitching
    y_top = (opt['cat_padding']) * opt['pic_up']
    y_left = (opt['cat_padding']) * opt['pic_up']
    y_down = (lh - opt['cat_padding']) * opt['pic_up']
    y_right = (lw - opt['cat_padding']) * opt['pic_up']
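    # Worked example of the geometry with the values above (pic_h = pic_w = 512,
    # pic_up = 2, cat_padding = 5): lh = lw = 256, stride_h = stride_w = 246,
    # y_top = y_left = 10 and y_down = y_right = 502, so each 512x512 model output
    # contributes a 492x492 center region -- exactly stride * pic_up pixels -- while
    # the cat_padding border of every tile overlaps its neighbours and is discarded.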

    for img_name in file_list:
        # load the image; single-channel images are promoted to RGB
        img = Image.open(path + '/' + img_name)
        if len(img.getbands()) == 1:
            img = img.convert("RGB")
        c = 3

        w, h = img.size
        org_img = img
        org_w, org_h = w, h
        dst_w, dst_h = w * opt['pic_up'], h * opt['pic_up']

        # pad the edges so the (h x w) input tiles evenly into overlapping blocks
        pad_h = h + opt['cat_padding'] * 2
        pad_w = w + opt['cat_padding'] * 2
        pad_h_off = ((pad_h - 2 * opt['cat_padding']) % stride_h)
        pad_w_off = ((pad_w - 2 * opt['cat_padding']) % stride_w)

        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off
        transform_pad = transforms.Compose([
            transforms.CenterCrop((pad_h, pad_w)),
        ])
        img = transform_pad(img)
        w, h = img.size

        # size of the super-resolved image
        sr_h = h * opt['pic_up']
        sr_w = w * opt['pic_up']

        transform_mid = transforms.Compose([
            transforms.CenterCrop((h, w)),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ])
        transform_last = transforms.Compose([
            transforms.CenterCrop((dst_h, dst_w)),
            transforms.ToTensor(),
            #transforms.Normalize(means, stds), not needed because Y_ has already been un-normalized
        ])

        # output buffer for the stitched super-resolved image
        Y = torch.zeros((1, 3, sr_h, sr_w))
        img_ = transform_mid(img)

        top = 0
        down = top + lh
        while down <= h:
            left = 0
            right = left + lw
            while right <= w:

                X_ = img_[:, top:down, left:right].detach()
                X_ = X_.reshape(-1, c, lh, lw)

                # run each channel through its own model; gradients are not needed here
                with torch.no_grad():
                    Y_1 = model1(X_[:, 0].reshape(-1, 1, lh, lw))
                    Y_2 = model2(X_[:, 1].reshape(-1, 1, lh, lw))
                    Y_3 = model3(X_[:, 2].reshape(-1, 1, lh, lw))

                Y_ = torch.stack([Y_1[:, 0], Y_2[:, 0], Y_3[:, 0]], dim=1)

                cent_top = (top + opt['cat_padding']) * opt['pic_up']
                cent_left = (left + opt['cat_padding']) * opt['pic_up']
                cent_down = (down - opt['cat_padding']) * opt['pic_up']
                cent_right = (right - opt['cat_padding']) * opt['pic_up']

                # keep only the center of the tile; the overlapping border is discarded
                Y[0, :, cent_top:cent_down, cent_left:cent_right] = Y_[:, :, y_top:y_down, y_left:y_right]

                # move the patch to the right
                left += stride_w
                right = left + lw
            # move the patch down
            top += stride_h
            down = top + lh

        Y = verse_normalize(Y, means, stds)
        Y = Y[0]
        Y = tensor_to_pil(Y)
        Y = transform_last(Y)
        print(" Psnr = ", psnr(Y, transform_last(org_img)))
        Y = tensor_to_pil(Y)
        Y.save("./ESPCN.png")
        display_img(Y, org_img)

        plt.pause(1)
        plt.clf()


"""

For training run this command:
python face_cnn.py train

For testing run this command:
python face_cnn.py test

"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    if (len(args) == 1) and (args[0] == "train"):
        train()
    elif (len(args) == 1) and (args[0] == "run"):
        run()
    else:
        test()

