# 苗火 Nicholas
# [DL] SRGAN
# 2022-6-9 萧
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random

from torchvision.models import *

# ImageNet channel statistics, used by torchvision Normalize and verse_normalize.
means = np.array([0.485, 0.456, 0.406])
stds = np.array([0.229, 0.224, 0.225])

channels = 3        # RGB
pic_height = 128    # HR (target) patch height
pic_width = 128     # HR (target) patch width
upscale_factor = 2  # LR -> HR scale factor (LR patches are 64x64)
hr_shape = (pic_height, pic_width)

batchs = 1000       # batches per training run (see dataloader.__next__)
batch_size = 4

learning_rate = 0.0001
# Adam betas ("beats" looks like a typo for "betas"; names kept — referenced below).
beats_l = 0.9
beats_r = 0.999

"""
# errors
def psnr(img1, img2):
img1 = Variable(img1)
img2 = Variable(img2)
mse = ( (img1/1.0) - (img2/1.0) ) ** 2
mse = np.mean( np.array(mse) )
if mse < 1.0e-10:
return 100
return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2):
    """Peak signal-to-noise ratio between two images given in [0, 255].

    Both inputs are torch tensors of the same shape. Pixels are rescaled
    to [0, 1], so MAX = 1 and:
        PSNR = 10 * log10(MAX**2 / MSE) = 10 * log10(1 / MSE)
    Returns 100 for (near-)identical images to avoid log of ~0.
    """
    diff = (img1 / 255.0) - (img2 / 255.0)
    # detach() so tensors coming straight out of the generator
    # (requires_grad=True) do not crash the conversion to a plain float.
    mse = float(torch.mean(diff.detach() ** 2))
    if mse < 1.0e-10:
        return 100
    # Fixed: the original returned 10*log10(1/sqrt(mse)), i.e. HALF the true PSNR.
    return 10 * math.log10(1.0 / mse)


def rgb2ycbcr(rgb_image):
    """Convert an RGB image (H, W, 3) to YCbCr (BT.601 'studio swing' coefficients).

    Inputs that are not rank-3 with 3 channels are returned unchanged
    (same object), matching the original's silent passthrough.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    rgb_image = rgb_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    # One vectorized matmul over all pixels replaces the original O(H*W)
    # Python double loop: out[i, j] = M @ rgb[i, j] + shift.
    return rgb_image @ transform_matrix.T + shift_matrix

def ycbcr2rgb(ycbcr_image):
    """Convert a YCbCr image (H, W, 3) back to RGB (inverse of rgb2ycbcr).

    Returns uint8; out-of-range values truncate/wrap exactly as the
    original astype(np.uint8) did. Non-3-channel inputs pass through.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    ycbcr_image = ycbcr_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    # inv @ x - inv @ shift == inv @ (x - shift); one vectorized matmul
    # replaces the original O(H*W) per-pixel Python double loop.
    rgb_image = (ycbcr_image - shift_matrix) @ transform_matrix_inv.T
    return rgb_image.astype(np.uint8)


def verse_normalize(img, means, stds):
    """Reverse torchvision-style normalization in place.

    For each of the first 3 channels c of `img` (N, C, H, W):
        img[:, c] = img[:, c] * stds[c] + means[c]
    Mutates `img` and also returns it.
    """
    mean_t = torch.Tensor(means)
    std_t = torch.Tensor(stds)
    for c, (m, s) in enumerate(zip(mean_t[:3], std_t[:3])):
        img[:, c] = img[:, c] * s + m
    return img


class dataloader:
    """Iterates (LR, HR) image batches from a directory of image files.

    The file list is split at `test_offset` into a train part (test=False)
    and a test part (test=True). Each item is a pair of tensors:
    X (downscaled + normalized) and Y (full-size + normalized).
    `batchs` <= 0 means the iterator never raises StopIteration.
    """

    def __init__(self, path, batchs, batch_size, test=False, test_offset=0.9, order=False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0          # batches served so far in this pass
        self.batchs = batchs         # batches per pass (<= 0: unlimited)
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # Index bounds of the split this loader serves.
        self.file_start = 0
        self.file_stop = self.file_number * self.test_offset - 1
        if self.test:
            self.file_start = self.file_number * self.test_offset
            self.file_stop = self.file_number - 1

        self.file_start = int(np.floor(self.file_start))
        self.file_stop = int(np.floor(self.file_stop))

        self.file_idx = self.file_start

        # LR pipeline: bicubic downscale by `upscale_factor`, then normalize.
        self.downsample = transforms.Compose([
            transforms.Resize((pic_height // upscale_factor, pic_width // upscale_factor), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ])

        # HR pipeline: bilinear resize to the target size, then normalize.
        self.transform = transforms.Compose([
            transforms.Resize((pic_height, pic_width), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ])

    def get_len(self):
        """Number of files available to this split."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        # `and` instead of the original bitwise `&`: same result here,
        # but short-circuiting and idiomatic.
        if self.batchs_cnt >= self.batchs and self.batchs > 0:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        X = []
        Y = []
        for _ in range(self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)

        return torch.stack(X, 0), torch.stack(Y, 0)

    def _next(self):
        """Load the next single (LR, HR) tensor pair, wrapping at file_stop."""
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start

        # os.path.join instead of manual '/' concatenation.
        file_path = os.path.join(self.path, self.file_list[self.file_idx])
        self.file_idx += 1
        Y = Image.open(file_path)

        if len(Y.getbands()) == 1:
            # Grayscale image -> 3-channel RGB so the models always see 3 channels.
            Y = Y.convert("RGB")

        Y_ = self.transform(Y)
        X_ = self.downsample(Y)
        del Y

        return X_, Y_

class ResidualBlock(nn.Module):
    """SRGAN residual unit: Conv-BN-PReLU-Conv-BN with an identity skip."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # NOTE(review): 0.8 is passed positionally, so it sets BatchNorm2d's
        # *eps*, not momentum — likely an upstream quirk; preserved as-is.
        layers = [
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """SRResNet-style generator: conv head, residual trunk, global skip,
    one x2 pixel-shuffle upsampling stage, and a tanh-activated conv tail."""

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # Head: large 9x9 receptive field, spatial size unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Residual trunk.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(n_residual_blocks)]
        )

        # Post-trunk conv whose output is summed with the head output.
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
        )

        # Single x2 upsampling stage (the original loop ran exactly once).
        stage = [
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(),
        ]
        self.upsampling = nn.Sequential(*stage)

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )

    def forward(self, x):
        head = self.conv1(x)
        trunk = self.conv2(self.res_blocks(head))
        fused = head + trunk   # global skip connection around the trunk
        return self.conv3(self.upsampling(fused))

class Discriminator(nn.Module):
    """PatchGAN-style SRGAN discriminator.

    Four stride-2 stages halve H and W four times, so for an input of
    shape (C, H, W) the score map has shape (1, H // 16, W // 16).
    """

    def __init__(self, input_shape):  # e.g. (3, 128, 128)
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        self.output_shape = (1, int(in_height / 2 ** 4), int(in_width / 2 ** 4))

        def block(cin, cout, first=False):
            """One stage: 3x3 conv (stride 1) + 3x3 conv (stride 2), LeakyReLU,
            BatchNorm everywhere except right after the input conv."""
            stage = [nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1)]
            if not first:
                stage.append(nn.BatchNorm2d(cout))
            stage += [
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(cout, cout, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            return stage

        stages = []
        prev = in_channels
        for idx, width in enumerate([64, 128, 256, 512]):
            stages += block(prev, width, first=(idx == 0))
            prev = width

        # 1-channel score map head.
        stages.append(nn.Conv2d(prev, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*stages)

    def forward(self, img):
        return self.model(img)

class FeatureExtractor(nn.Module):
    """Frozen VGG19 feature extractor for the SRGAN content (perceptual) loss."""

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # Pretrained VGG19 — downloads weights on first use.
        vgg19_model = vgg19(pretrained=True)
        # Keep only the first 18 layers of vgg19.features (the original
        # comment said vgg16, but the code uses vgg19).
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)

# Module-level model / loss / optimizer construction (runs on import).
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(channels, *hr_shape))
feature_extractor = FeatureExtractor()

print(generator)
print(discriminator)

# The VGG features are a fixed metric — keep BatchNorm/Dropout in eval mode.
feature_extractor.eval()

criterion_GAN = torch.nn.MSELoss()       # adversarial (MSE / least-squares) loss
criterion_content = torch.nn.L1Loss()    # content loss in VGG feature space

optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beats_l, beats_r))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beats_l, beats_r))

tensor_to_pil = ToPILImage()


'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
plt.imshow(X)
plt.pause(0.1)
plt.ioff()

iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()


'''
# Training data: `batchs` batches of `batch_size` images per pass.
dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', batchs, batch_size)
#dataloader_train = dataloader('./dataset/Set14/image_SRF_3', 10, 4)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
#dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 4)
# Test data: batchs=0 means this iterator never raises StopIteration
# (see dataloader.__next__), so it must be driven with next(), not a for loop.
dataloader_test = dataloader('./images', 0, 3)
new_training = 0  # appears unused in this file — TODO confirm before removing

Tensor = torch.Tensor  # alias used when building GAN target label tensors

def train():
    """Adversarial training loop: alternately updates generator and discriminator.

    Resumes from ./model/*.m checkpoints when present; every 5 batches it
    runs test() and saves new checkpoints. Uses the module-level models,
    losses, optimizers and dataloader_train.
    """
    try:
        generator.load_state_dict(torch.load("./model/generator.m"))
        discriminator.load_state_dict(torch.load("./model/discriminator.m"))
    except OSError:
        # Narrowed from a bare `except:`: a missing/unreadable checkpoint is a
        # normal fresh start; any other failure should surface, not be hidden.
        print(" ************** Model config files not accessable !!! ************")

    plt.ion()

    i = 0

    for X, Y in dataloader_train:
        # ---- Generator step -------------------------------------------------
        # Per-patch GAN targets matching the discriminator's score-map shape.
        valid = Variable(Tensor(np.ones((X.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((X.size(0), *discriminator.output_shape))), requires_grad=False)

        optimizer_G.zero_grad()
        gen_Y = generator(X)
        loss_GAN = criterion_GAN(discriminator(gen_Y), valid)

        # Perceptual (content) loss in VGG feature space; target detached.
        features_gen_Y = feature_extractor(gen_Y)
        features_Y = feature_extractor(Y)
        loss_content = criterion_content(features_gen_Y, features_Y.detach())

        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator step ---------------------------------------------
        optimizer_D.zero_grad()
        loss_valid = criterion_GAN(discriminator(Y), valid)
        loss_fake = criterion_GAN(discriminator(gen_Y.detach()), fake)

        loss_D = (loss_valid + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        print("Epoch %d/%d: D loss: %f, G loss: %f"
              % (i, batchs, loss_D.item(), loss_G.item()) + '\n')

        i += 1
        if i % 5 == 0:
            test()
            # Create the checkpoint directory so torch.save cannot fail
            # with FileNotFoundError on a fresh clone.
            os.makedirs("./model", exist_ok=True)
            torch.save(generator.state_dict(), "./model/generator.m")
            torch.save(discriminator.state_dict(), "./model/discriminator.m")

    plt.ioff()


def display_img(img1, img2, img3=None):
    """Show two (or three, if img3 is given) images side by side."""
    images = [img1, img2] if img3 is None else [img1, img2, img3]
    cols = len(images)
    for idx, image in enumerate(images, start=1):
        plt.subplot(1, cols, idx)
        plt.imshow(image)
    plt.show()

def test():
    """Run the generator on one batch from dataloader_test, print PSNR,
    and display SR / HR / LR images side by side.

    Uses the module-level generator, dataloader_test, psnr, verse_normalize,
    display_img and tensor_to_pil.
    """
    # dataloader_test was built with batchs=0 (never stops), so we pull
    # exactly one batch with next() instead of a for loop.
    for _ in range(1):
        X, Y = next(dataloader_test)
        print(X.shape)

        # Inference only: no_grad skips graph construction and keeps the
        # numpy-based psnr from failing on a grad-tracking tensor.
        with torch.no_grad():
            Y_ = generator(X)

        print(" Psnr = ", psnr(Y_, Y))

        # Undo normalization before converting to PIL for display.
        Y_ = verse_normalize(Y_, means, stds)
        Y = verse_normalize(Y, means, stds)
        X = verse_normalize(X, means, stds)
        display_img(tensor_to_pil(Y_[0]), tensor_to_pil(Y[0]), tensor_to_pil(X[0]))

        plt.pause(0.1)

def run():
    """Upscale every image in a folder tile-by-tile with the trained generator.

    Each image is padded/cropped to a whole number of LR-sized tiles
    (pic_height/upscale_factor x pic_width/upscale_factor), every tile is
    super-resolved independently, the SR tiles are stitched into one canvas,
    and the result is center-cropped to the exact upscaled size and saved.
    """
    generator.load_state_dict(torch.load("./model/generator.m"))
    discriminator.load_state_dict(torch.load("./model/discriminator.m"))

    generator.eval()
    discriminator.eval()

    path = "./dataset/Set5/image_SRF_2"

    file_list = os.listdir(path)
    #plt.ion()


    for img_name in file_list:
        img = Image.open(path+'/'+img_name)
        if len(img.getbands()) == 1:
            # grayscale -> 3-channel RGB
            img = img.convert("RGB")
        c = 3

        # turn ( h X w ) to ( uh X uw ): target upscaled size
        w, h = img.size
        uh = h * upscale_factor
        uw = w * upscale_factor

        # block size: the LR tile the generator consumes
        lh = pic_height // upscale_factor
        lw = pic_width // upscale_factor

        # blocks number (rounded up: borders get partial tiles)
        rows = math.ceil( h / lh )
        cols = math.ceil( w / lw )

        # tmp img size: rounded up to a whole number of tiles
        h_mid = rows * lh
        w_mid = cols * lw
        uh_mid = h_mid * upscale_factor
        uw_mid = w_mid * upscale_factor

        # NOTE(review): CenterCrop to a size >= the image relies on torchvision
        # zero-padding the borders — confirm with the installed version.
        transform_mid = transforms.Compose([
            transforms.CenterCrop( (h_mid, w_mid) ),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ])
        transform_last = transforms.Compose([
            transforms.CenterCrop( (uh, uw) ),
            transforms.ToTensor()
        ])

        # Output canvas for the stitched, upscaled tiles.
        Y = torch.zeros( (1, 3, uh_mid, uw_mid) )
        img_ = transform_mid(img)
        row = 0
        col = 0
        for row in range(rows):
            for col in range(cols):
                # Cut one LR tile, super-resolve it, and paste the result
                # into the corresponding (upscale_factor x larger) position.
                X_ = img_[:, row * lh : (row + 1) * lh, col * lw : (col + 1) * lw].detach()
                X_ = X_.reshape(-1, c, lh, lw)
                Y_ = generator(X_)
                Y[0, :, row * pic_height: (row+1) * pic_height, col * pic_width: (col+1) * pic_width] = Y_

        # Undo normalization, crop back to the exact target size, save, show.
        Y = verse_normalize(Y, means, stds)
        Y = Y[0]
        Y = tensor_to_pil(Y)
        Y = transform_last(Y)
        Y = tensor_to_pil(Y)
        Y.save("./SRGAN.png")
        display_img( Y, img)

        plt.pause(1)
        plt.clf()

    #plt.ioff()

"""

For training run this command:
python face_cnn.py train

For testing run this command:
python face_cnn.py test

"""
if __name__ == '__main__':
    # Dispatch on the single CLI argument: "train", "run", or default test().
    args = sys.argv[1:]
    print(args, len(args))
    # `and` instead of the original bitwise `&` — short-circuits, so the
    # args[0] access is only evaluated when an argument exists.
    if len(args) == 1 and args[0] == "train":
        train()
    elif len(args) == 1 and args[0] == "run":
        run()
    else:
        test()


# --- blog page residue below (comment-form labels), kept for provenance ---
# 发表评论: (post a comment:)
# 昵称 (nickname)
# 邮件地址 (选填) (email address, optional)
# 个人主页 (选填) (personal homepage, optional)
# 内容 (content)