[DL]VDSR:Super-resolution with VDSR

2022-4-28 写技术

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random

# Per-channel mean / standard deviation. In this file they are only referenced
# by the commented-out transforms.Normalize(means, stds) lines in dataloader.
# NOTE(review): the values look like the usual ImageNet statistics - confirm.
means = [0.485, 0.456, 0.406]
stds = [ 0.229, 0.224, 0.225]

# Side length (in pixels) of the square images fed to the networks;
# pic_height is the output size passed to RandomResizedCrop.
pic_height = 512
pic_width = 512


"""
# errors
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = (  (img1/1.0) - (img2/1.0)   ) ** 2 
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2, max_val=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    PSNR is defined as ``10 * log10(MAX^2 / MSE)``. Both inputs are divided
    by ``max_val`` before the error is computed, so the normalised peak is 1.

    Args:
        img1: Tensor (or array-like) holding the first image.
        img2: Tensor (or array-like) with the same shape as ``img1``.
        max_val: Largest value a pixel can take. The default of 255 matches
            8-bit images; pass ``1.0`` for data already scaled to [0, 1].

    Returns:
        The PSNR as a float, or ``100`` when the images are (almost)
        identical and the ratio would be unbounded.
    """
    # detach() keeps autograd out of the metric; float64 avoids integer
    # wrap-around on uint8 inputs and reduces rounding error in the mean.
    a = torch.as_tensor(img1).detach().to(torch.float64)
    b = torch.as_tensor(img2).detach().to(torch.float64)
    mse = torch.mean(((a - b) / max_val) ** 2).item()
    if mse < 1.0e-10:
        return 100
    # The peak is 1 after normalisation, hence 10 * log10(1 / mse).
    return 10 * math.log10(1.0 / mse)


# ITU-R BT.601 "studio swing" RGB -> YCbCr matrix for 8-bit data.
# The green luma weight is 0.504 (the first row must sum to 219/255 = 0.859);
# the value 0.564 used previously pushed Y above the nominal 16..235 range.
_RGB_TO_YCBCR = np.array([[0.257, 0.504, 0.098],
                          [-0.148, -0.291, 0.439],
                          [0.439, -0.368, -0.071]])
_YCBCR_TO_RGB = np.linalg.inv(_RGB_TO_YCBCR)
_YCBCR_OFFSET = np.array([16.0, 128.0, 128.0])


def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB array on the 0-255 scale to YCbCr.

    Args:
        rgb_image: numpy array. Anything that is not 3-D with 3 channels in
            the last axis is returned unchanged.

    Returns:
        A float64 array of the same shape holding Y, Cb, Cr in the last
        axis, or the input object itself when it is not a 3-channel image.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    rgb = rgb_image.astype(np.float64)
    # One matrix product over the channel axis replaces the per-pixel loop.
    return np.dot(rgb, _RGB_TO_YCBCR.T) + _YCBCR_OFFSET


def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) YCbCr array back to 8-bit RGB.

    This is the exact inverse of ``rgb2ycbcr``.

    Args:
        ycbcr_image: numpy array. Anything that is not 3-D with 3 channels
            in the last axis is returned unchanged.

    Returns:
        A uint8 array of the same shape, or the input object itself when it
        is not a 3-channel image.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    ycbcr = ycbcr_image.astype(np.float64)
    rgb = np.dot(ycbcr - _YCBCR_OFFSET, _YCBCR_TO_RGB.T)
    # Round to nearest and clip before the cast: a bare astype(uint8)
    # truncates and wraps values that fall outside 0..255.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)

class dataloader:
    """Iterator yielding (blurred, original) image batches from a directory.

    The directory listing is split into a leading "train" part and a
    trailing "test" part at ``test_offset``. Each instance serves files from
    one of the two parts and cycles through them. Every sample is the
    original image (target ``Y``) and a copy blurred with
    ``ImageFilter.BLUR`` (input ``X``), both passed through ``self.transform``.

    Args:
        path: Directory containing the image files.
        batchs: Number of batches per pass; iteration stops after this many
            batches. A value <= 0 makes the iterator endless.
        batch_size: Number of images stacked into one batch.
        test: Serve the trailing (test) part of the listing instead of the
            leading (train) part.
        test_offset: Fraction of the listing that belongs to the train part.
        order: Sort the listing; otherwise it is shuffled.

    NOTE: with ``order=False`` every instance shuffles independently, so a
    train and a test loader built on the same directory do not necessarily
    see disjoint files.
    """

    def __init__(self, path, batchs,  batch_size, test= False, test_offset = 0.9, order = False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # Inclusive [file_start, file_stop] index range served by this loader.
        self.file_start, self.file_stop = self._split_range(
            self.file_number, test, test_offset)
        self.file_idx = self.file_start

        # Half-resolution resize; kept as an attribute but not used by _next().
        self.downsample = transforms.Compose([
            transforms.Resize( (int(pic_height/2),int(pic_width/2)), Image.BILINEAR),
            ])

        # scale=(1, 1) asks for a crop covering the whole image area, resized
        # to pic_height x pic_height.
        # NOTE(review): the crop is sampled separately for X and Y - confirm
        # that both always receive the same region.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop( pic_height, scale=(1,1) ),
            transforms.ToTensor()
            ])

    @staticmethod
    def _split_range(file_number, test, test_offset):
        """Return the inclusive (start, stop) index range of one split."""
        split = int(np.floor(file_number * test_offset))
        if test:
            start, stop = split, file_number - 1
        else:
            start, stop = 0, split - 1
        # Clamp so that a non-empty listing never produces an empty or
        # out-of-bounds range (e.g. a single file, or test_offset == 1).
        last = max(file_number - 1, 0)
        start = min(max(start, 0), last)
        stop = min(max(stop, start), last)
        return start, stop

    def get_len(self):
        """Return the number of files in the range served by this loader."""
        if self.file_number == 0:
            return 0
        return self.file_stop - self.file_start + 1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next (X, Y) batch, each stacked along a new dim 0.

        Raises:
            StopIteration: after ``batchs`` batches; the counter is reset so
                the loader can be iterated again.
        """
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        pairs = [self._next() for _ in range(self.batch_size)]
        X = torch.stack([x for x, _ in pairs], 0)
        Y = torch.stack([y for _, y in pairs], 0)
        return X, Y

    def _next_index(self):
        """Return the index of the next file, wrapping past ``file_stop``."""
        # file_stop is inclusive, so wrap only once the index has moved past it.
        if self.file_idx > self.file_stop:
            self.file_idx = self.file_start
        idx = self.file_idx
        self.file_idx += 1
        return idx

    def _next(self):
        """Load one image and return its (blurred, original) tensor pair."""
        file_path = os.path.join(self.path, self.file_list[self._next_index()])
        # The context manager releases the file handle; convert() loads the
        # pixel data and returns an independent RGB image, so grayscale,
        # palette, RGBA and CMYK files all end up with three channels.
        with Image.open(file_path) as img:
            Y = img.convert("RGB")

        X = Y.filter(ImageFilter.BLUR)

        X = self.transform(X)
        Y = self.transform(Y)

        return X,Y

class VDSR(nn.Module):
    """Single-channel VDSR-style residual predictor.

    A stack of 3x3 convolutions (stride 1, padding 1, so the spatial size is
    preserved) with a ReLU after every convolution except the last. The
    network maps a 1-channel image to a 1-channel output.

    Args:
        num_layers: Number of convolution layers. The default of 6
            reproduces the network used in this file; the original paper
            uses 20.
        channels: Number of feature maps in the hidden layers.

    Raises:
        ValueError: if ``num_layers`` is smaller than 2.
    """

    def __init__(self, num_layers=6, channels=64):
        super(VDSR, self).__init__()
        if num_layers < 2:
            raise ValueError("num_layers must be at least 2")

        # Channel count before/after each convolution: 1 -> C -> ... -> C -> 1.
        widths = [1] + [channels] * (num_layers - 1) + [1]
        layers = []
        for idx in range(num_layers):
            layers.append(nn.Conv2d(widths[idx], widths[idx + 1],
                                    kernel_size=3, stride=1, padding=1))
            if idx < num_layers - 1:
                layers.append(nn.ReLU())
        # Same Sequential layout (conv at even indices) as the hand-written
        # version, so state_dict keys stay "cnn.0", "cnn.2", ...
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        return self.cnn(x)

    def init_weights(self):
        """Re-initialise all convolutions (He-normal weights, zero biases)."""
        print(" ****************************")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

# Converter used by test() and run() to turn a tensor into a PIL image
# for display.
tensor_to_pil = ToPILImage()

# One network per colour channel: train(), test() and run() feed channel 0
# to model1, channel 1 to model2 and channel 2 to model3.
model1 = VDSR()
model2 = VDSR()
model3 = VDSR()

'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
    plt.imshow(X)
    plt.pause(0.1)
plt.ioff()

iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()


'''
# Training data: 200 batches of 1 image per pass over the loader.
dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', 200, 1)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
# Evaluation data: 1 batch of 1 image per call to test().
dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 1)
# 1 -> train() starts from the freshly constructed models;
# any other value -> train() reloads the models saved under ./model/.
new_training = 1

def train():
    """Train the three per-channel networks and save them under ./model/.

    Each network learns the residual ``Y - X`` of one colour channel with an
    MSE loss and its own Adam optimiser. ``test()`` is called after every
    10 batches. When ``new_training`` is not 1 the previously saved models
    are loaded first.
    """
    global model1
    global model2
    global model3

    # Create the output directory up front so that a missing folder is
    # reported before training starts rather than when saving at the end.
    os.makedirs("./model", exist_ok=True)

    if new_training == 1:
        print(" new training ...")
        model1.init_weights()
        model2.init_weights()
        model3.init_weights()
    else:
        print(" continue training ...")
        # NOTE: torch.load unpickles whole model objects; only load files
        # from a trusted source.
        model1 = torch.load("./model/srcnn_l1.m")
        model2 = torch.load("./model/srcnn_l2.m")
        model3 = torch.load("./model/srcnn_l3.m")

    models = (model1, model2, model3)
    loss = nn.MSELoss()
    optimizers = [
        torch.optim.Adam(model.parameters(), lr = 0.00001, betas = (0.9, 0.999))
        for model in models
    ]

    plt.ion()

    for step, (X, Y) in enumerate(dataloader_train, start=1):
        # The networks predict the residual between target and input.
        R = Y - X

        costs = []
        for channel, (model, optimizer) in enumerate(zip(models, optimizers)):
            # test() switches the models to eval mode, so re-enable training
            # mode on every batch.
            model.train()
            optimizer.zero_grad()
            # Slicing with channel:channel+1 keeps the channel axis, giving
            # the (N, 1, H, W) input the network expects.
            prediction = model(X[:, channel:channel + 1])
            cost = loss(prediction, R[:, channel:channel + 1])
            cost.backward()
            optimizer.step()
            costs.append(cost.item())

        print( "batch:{}, loss  is {:.4f} {:.4f} {:.4f} ".format(step - 1, *costs) )

        if step % 10 == 0:
            test()

    plt.ioff()
    torch.save(model1, "./model/srcnn_l1.m")
    torch.save(model2, "./model/srcnn_l2.m")
    torch.save(model3, "./model/srcnn_l3.m")

def test():
    """Run the three networks on the test loader, print PSNR and show images.

    For every batch the predicted per-channel residuals are added to the
    blurred input ``X``. The PSNR of the reconstruction and of the raw input
    against the target ``Y`` is printed, and the first reconstruction and
    input of the batch are displayed side by side.
    """
    models = (model1, model2, model3)

    for X,Y in dataloader_test:

        print(X.shape)
        for model in models:
            model.eval()

        _,c,h,w = X.shape

        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            R = torch.cat(
                [model(X[:, channel:channel + 1])
                 for channel, model in enumerate(models)],
                dim = 1)

        print( c, h, w)

        # Input plus residual, clamped to the [0, 1] range produced by
        # ToTensor so out-of-range values do not wrap when converted to
        # 8-bit for display.
        Y_ = (R + X).clamp(0.0, 1.0)

        # psnr() normalises by 255, so rescale the [0, 1] tensors first.
        print(" Psnr = ", psnr(Y_ * 255.0, Y * 255.0),
              "  psnr2 = ", psnr(X * 255.0, Y * 255.0) )

        # Left: reconstruction, right: blurred input.
        plt.subplot(1, 2, 1)
        plt.imshow(tensor_to_pil(Y_[0]))

        plt.subplot(1, 2, 2)
        plt.imshow(tensor_to_pil(X[0]))

        plt.show()
        plt.pause(0.1)



def run():
    """Load the saved networks and enhance every image found in ./images.

    Each image is converted to RGB, resized through the same crop transform
    used for training, and shown next to its reconstruction
    (input + predicted residual).
    """
    # The models are rebound below; without this declaration the names
    # would be local and could not replace the module-level models.
    global model1
    global model2
    global model3

    # NOTE: torch.load unpickles whole model objects; only load files from
    # a trusted source.
    model1 = torch.load("./model/srcnn_l1.m")
    model2 = torch.load("./model/srcnn_l2.m")
    model3 = torch.load("./model/srcnn_l3.m")

    models = (model1, model2, model3)
    for model in models:
        model.eval()

    path = "./images"

    transform = transforms.Compose([
        transforms.RandomResizedCrop( pic_height, scale=(1,1) ),
        transforms.ToTensor()
            ])

    for img_name in os.listdir(path):
        # Close the file handle promptly; convert() returns an independent
        # 3-channel image whatever the source mode is.
        with Image.open(os.path.join(path, img_name)) as img:
            rgb = img.convert("RGB")

        # Add the batch axis: (C, H, W) -> (1, C, H, W).
        X = transform(rgb).unsqueeze(0)

        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            residual = torch.cat(
                [model(X[:, channel:channel + 1])
                 for channel, model in enumerate(models)],
                dim = 1)

        # Input plus residual, clamped to [0, 1] so the 8-bit conversion
        # for display cannot wrap around.
        Y = (residual + X).clamp(0.0, 1.0)

        # Left: reconstruction, right: input.
        plt.subplot(1, 2, 1)
        plt.imshow(tensor_to_pil(Y[0]))

        plt.subplot(1, 2, 2)
        plt.imshow(tensor_to_pil(X[0]))

        plt.show()
        plt.pause(1)
        plt.clf()

"""

For training run this command:
python face_cnn.py train

For testing run this command:
python face_cnn.py test

"""
if __name__ == '__main__':
    # Dispatch on the single command-line argument: "train" or "run";
    # anything else (including no argument at all) runs test().
    args = sys.argv[1:]
    print( args, len(args))
    # Read args[0] only when exactly one argument was given; the previous
    # `&` test evaluated args[0] even for an empty list.
    mode = args[0] if len(args) == 1 else None
    if mode == "train":
        train()
    elif mode == "run":
        run()
    else:
        test()


标签: python DL

发表评论:

Powered by anycle 湘ICP备15001973号-1