[DL]DRRN

2022-6-9 Technology

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random

means = [0.485, 0.456, 0.406]
stds = [ 0.229, 0.224, 0.225]

pic_height = 512
pic_width = 512

opt = {
    "pic_c":32, # middle convolution layers is 32
    "pic_h":pic_height, # low responsity images' height is 48
    "pic_w":pic_width,
    "pic_up":2, # upscale
    "heads":8,
    "vec_dim":512,
    "N":6,
    "x_vocab_len":0,
    "y_vocab_len":0,
    "sentence_len":80,
    "batchs":1000,
    "batch_size":10,
    "pad":0,
    "sof":1,
    "eof":2,
    "cat_padding": 5,
}
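# Only pic_h, pic_w, pic_up and cat_padding are read later (in run()); the other
# keys (heads, vec_dim, N, vocab/sentence lengths, batch settings, ...) are not
# used anywhere in this script.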


def verse_normalize(img, means, stds):
    means = torch.Tensor(means)
    stds = torch.Tensor(stds)
    for i in range(3):
        img[:,i] = img[:,i] * stds[i] + means[i]
    return img
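# verse_normalize is the inverse of transforms.Normalize(means, stds): Normalize
# maps x -> (x - mean) / std per channel, so x * std + mean recovers the original
# values. It expects a batched tensor of shape (N, 3, H, W).
# Example: a normalized red-channel value of 0.0 maps back to
# 0.0 * 0.229 + 0.485 = 0.485, the channel mean.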

"""
# errors
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = (  (img1/1.0) - (img2/1.0)   ) ** 2 
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2):
    # inputs are tensors scaled to [0, 1] by ToTensor, so the peak value is 1.0
    mse = torch.mean((img1.detach() - img2.detach()) ** 2).item()
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(1.0 / mse)
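# Quick check of the formula: if every pixel of two [0, 1] images differs by 0.1,
# mse = 0.01 and psnr = 10 * log10(1 / 0.01) = 20 dB.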


def rgb2ycbcr(rgb_image):
    """convert rgb into ycbcr"""
    if len(rgb_image.shape)!=3 or rgb_image.shape[2]!=3:
        return rgb_image
        #raise ValueError("input image is not a rgb image")
    rgb_image = rgb_image.astype(np.float32)
    # ITU-R BT.601 RGB -> YCbCr conversion matrix (studio swing)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    ycbcr_image = np.zeros(shape=rgb_image.shape)
    w, h, _ = rgb_image.shape
    for i in range(w):
        for j in range(h):
            ycbcr_image[i, j, :] = np.dot(transform_matrix, rgb_image[i, j, :]) + shift_matrix
    return ycbcr_image

def ycbcr2rgb(ycbcr_image):
    """convert ycbcr into rgb"""
    if len(ycbcr_image.shape)!=3 or ycbcr_image.shape[2]!=3:
        return ycbcr_image
        #raise ValueError("input image is not a rgb image")
    ycbcr_image = ycbcr_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    rgb_image = np.zeros(shape=ycbcr_image.shape)
    w, h, _ = ycbcr_image.shape
    for i in range(w):
        for j in range(h):
            rgb_image[i, j, :] = np.dot(transform_matrix_inv, ycbcr_image[i, j, :]) - np.dot(transform_matrix_inv, shift_matrix)
    return rgb_image.astype(np.uint8)
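
# A vectorized sketch of the same conversion (example helper, not used elsewhere
# in this script): numpy broadcasting applies the 3x3 matrix to every pixel at
# once, avoiding the per-pixel Python loops above.
def rgb2ycbcr_fast(rgb_image):
    """Same BT.601 RGB -> YCbCr conversion as rgb2ycbcr, without the loops."""
    rgb = rgb_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.504, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    # (H, W, 3) @ (3, 3)^T multiplies every pixel vector by the matrix
    return rgb @ transform_matrix.T + shift_matrix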

class dataloader:
    def __init__(self, path, batchs,  batch_size, test= False, test_offset = 0.9, order = False):
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        self.file_start = 0
        self.file_stop = self.file_number * self.test_offset - 1
        if self.test:
            self.file_start = self.file_number * self.test_offset
            self.file_stop = self.file_number - 1

        self.file_start = int(np.floor(self.file_start))
        self.file_stop = int(np.floor(self.file_stop))

        self.file_idx = self.file_start

        self.downsample = transforms.Compose([
            transforms.Resize( (int(pic_height/2),int(pic_width/2)), Image.BILINEAR),
            #transforms.RandomResizedCrop( pic_height, scale=(0.08,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            #transforms.ToTensor()
            #transforms.Normalize(means, stds)
            ])

        self.transform = transforms.Compose([
            #transforms.Resize( (pic_height,pic_width), Image.BILINEAR),
            transforms.RandomResizedCrop( pic_height, scale=(1,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor()
            #transforms.Normalize(means, stds)
            ])

    def get_len(self):
        return self.file_stop  - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        if (self.batchs_cnt >= self.batchs) & (self.batchs > 0):
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        X = []
        Y = []
        for i in range( self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)

        X = torch.stack(X, 0)
        Y = torch.stack(Y, 0)
        return X, Y

    def _next(self):
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start

        file_path = self.path + '/' + self.file_list[self.file_idx]
        self.file_idx += 1
        Y = Image.open(file_path)
        #print( "Y:", file_path, " (", Y.size )

        if len(Y.getbands()) == 1:
            Y = Y.convert("RGB")

        #X = self.downsample(Y)
        X = Y.filter(ImageFilter.BLUR)
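        # the downsample transform defined above is left unused: X is simply a
        # blurred copy of Y at the same resolution, so the networks effectively
        # learn deblurring rather than upscaling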

        X = self.transform(X)
        Y = self.transform(Y)

        return X,Y

class DRRN(nn.Module):
    def __init__(self):
        super(DRRN, self).__init__()
        # 32 channels in the hidden layers (the original DRRN uses 128)
        c = 32
        self.input = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=c, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
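        # conv1 and conv2 are reused for every recursive iteration in forward(),
        # so all iterations share the same weights (the core idea of DRRN)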

    def init_weights(self):
        print(" ****************************")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def forward(self, x):
        residual = x
        inputs = self.input(self.relu(x))
        out = inputs
        for _ in range(2): #25
            out = self.conv2(self.relu(self.conv1(self.relu(out))))
            out = torch.add(out, inputs)
            
        out = self.output(self.relu(out))
        out = torch.add(out, residual)
        return out
        
tensor_to_pil = ToPILImage()

def display_img(img1, img2):
    plt.subplot(1, 2, 1)
    plt.imshow(img1)

    plt.subplot(1, 2, 2)
    plt.imshow(img2)

    plt.show()

'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
    plt.imshow(X)
    plt.pause(0.1)
plt.ioff()

iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()


'''

model1 = DRRN()
model2 = DRRN()
model3 = DRRN()
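# the three models above are single-channel DRRNs, one per RGB channel;
# each is trained with its own optimizer below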


loss = nn.MSELoss()

optimizer1 = torch.optim.Adam(model1.parameters(), lr = 0.001, betas = (0.9, 0.999))
optimizer2 = torch.optim.Adam(model2.parameters(), lr = 0.001, betas = (0.9, 0.999))
optimizer3 = torch.optim.Adam(model3.parameters(), lr = 0.001, betas = (0.9, 0.999))

dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', 10, 2)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 2)
new_training = 1

def train():
    global model1
    global model2
    global model3

    if new_training == 1:
        print(" new training ...")
        model1.init_weights()
        model2.init_weights()
        model3.init_weights()
    else:
        print(" continue training ...")
        model1= torch.load("./model/srcnn_l1.m")
        model2= torch.load("./model/srcnn_l2.m")
        model3= torch.load("./model/srcnn_l3.m")


    running_loss = 0
    plt.ion()

    i = 0

    for X,Y in dataloader_train:
        #print(X.shape)

        '''
        Train
        '''
        model1.train()
        model2.train()
        model3.train()

        _,c,h,w = X.shape

        Y_1 = model1(X[:,0].reshape(-1,1,h,w))
        optimizer1.zero_grad()
        cost1 = loss(Y_1, Y[:,0].reshape(-1,1,h,w))
        cost1.backward()
        optimizer1.step()

        Y_2 = model2(X[:,1].reshape(-1,1,h,w))
        optimizer2.zero_grad()
        cost2 = loss(Y_2, Y[:,1].reshape(-1,1,h,w))
        cost2.backward()
        optimizer2.step()

        Y_3 = model3(X[:,2].reshape(-1,1,h,w))
        optimizer3.zero_grad()
        cost3 = loss(Y_3, Y[:,2].reshape(-1,1,h,w))
        cost3.backward()
        optimizer3.step()


        #running_loss += cost.item()
        print( "batch:{}, loss  is {:.4f} {:.4f} {:.4f} ".format(i, cost1, cost2, cost3) )

        i += 1
        if i%1 == 0:
            test(model1, model2, model3)

        '''
        Display images
        '''
        """
        Y_ = torch.stack( [Y_1[:,0], Y_2[:,0], Y_3[:,0] ], dim = 1)

        iii = tensor_to_pil(Y_[0])
        plt.subplot(1, 2, 1)
        plt.imshow(iii)
        
        iii = tensor_to_pil(X[0])
        plt.subplot(1, 2, 2)
        plt.imshow(iii)

        plt.show()
        plt.pause(0.1)
        """

    plt.ioff()
    torch.save(model1, "./model/srcnn_l1.m")
    torch.save(model2, "./model/srcnn_l2.m")
    torch.save(model3, "./model/srcnn_l3.m")

def test(model1, model2, model3):

    running_loss = 0
    #plt.ion()

    for X,Y in dataloader_test:

        #print(X.shape)
        model1.eval()
        model2.eval()
        model3.eval()

        _,c,h,w = X.shape

        Y_1 = model1(X[:,0].reshape(-1,1,h,w))
        Y_2 = model2(X[:,1].reshape(-1,1,h,w))
        Y_3 = model3(X[:,2].reshape(-1,1,h,w))

        print( c, h, w)
        Y_ = torch.stack( [Y_1[:,0], Y_2[:,0], Y_3[:,0] ], dim = 1)

        psnrv = psnr(Y_, Y)

        print(" psnr(SR, HR) = ", psnrv, "  psnr(blurred input, HR) = ", psnr(X, Y) )

        '''
        Display images
        '''
        iii = tensor_to_pil(Y_[0])
        plt.subplot(1, 2, 1)
        plt.imshow(iii)
        
        iii = tensor_to_pil(X[0])
        plt.subplot(1, 2, 2)
        plt.imshow(iii)

        plt.show()
        plt.pause(0.1)

    #plt.ioff()



def run():
    model1= torch.load("./model/srcnn_l1.m")
    model2= torch.load("./model/srcnn_l2.m")
    model3= torch.load("./model/srcnn_l3.m")

    model1.eval()
    model2.eval()
    model3.eval()
    path = "./dataset/Set5/image_SRF_2"

    file_list = os.listdir(path)
    #plt.ion()


    # tile size fed to the models (the already-upscaled image is processed in tiles)
    lh = opt['pic_h']
    lw = opt['pic_w']

    # stride between neighbouring tiles (tiles overlap by cat_padding on each side)
    stride_h = lh - opt['cat_padding'] * 2
    stride_w = lw - opt['cat_padding'] * 2

    # central sub-region of each model output that is copied into the result
    y_top = (opt['cat_padding'])
    y_left = (opt['cat_padding'])
    y_down = (lh - opt['cat_padding'])
    y_right = (lw - opt['cat_padding'])
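    # Each tile is run through the three per-channel models. Convolutions near a
    # tile border lack context, so only the central region of every output tile is
    # kept; the cat_padding-wide overlap between neighbouring tiles supplies that
    # context. The outermost cat_padding pixels of the padded image are never
    # written and are cropped away by transform_last at the end.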

    for img_name in file_list:
        # get images
        img = Image.open(path+'/'+img_name)
        if len(img.getbands()) == 1:
            img = img.convert("RGB")
        c = 3

        org_img = img
        org_w, org_h = org_img.size
        dst_w, dst_h = org_w * opt['pic_up'], org_h * opt['pic_up']

        transform_up = transforms.Compose([
            #transforms.RandomResizedCrop(pad_h, scale=(2.0,2.0)),
            transforms.Resize( (dst_h, dst_w) ),
            #transforms.CenterCrop( (pad_h, pad_w) ),
            ])
        img = transform_up(img)

        # pad the edges: enlarge the (dst_h x dst_w) image to (pad_h x pad_w)
        pad_h = dst_h + opt['cat_padding'] * 2
        pad_w = dst_w + opt['cat_padding'] * 2
        pad_h_off = ((pad_h - 2*opt['cat_padding']) % stride_h)
        pad_w_off = ((pad_w - 2*opt['cat_padding']) % stride_w)

        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off
        # pad_h must == pad_w
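        # after the adjustments above, (pad_h - 2*cat_padding) and
        # (pad_w - 2*cat_padding) are exact multiples of stride_h / stride_w,
        # so the tile loops below cover the whole padded image with no remainder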
        transform_pad = transforms.Compose([
            #transforms.RandomResizedCrop(pad_h, scale=(2.0,2.0)),
            transforms.CenterCrop( (pad_h, pad_w) ),
            ])
        img = transform_pad(img)
        w,h = img.size

        # convert the padded image to a tensor (normalization stays disabled, as in training)
        transform_mid = transforms.Compose([
            #transforms.CenterCrop( (sr_h, sr_w) ),
            transforms.ToTensor(),
            #transforms.Normalize(means, stds)
            ])
        transform_last = transforms.Compose([
            transforms.CenterCrop( (dst_h, dst_w) ),
            transforms.ToTensor(),
            #transforms.Normalize(means, stds), as the Y_ has been verse normalized.
            ])

        Y = torch.zeros( (1, 3, h, w) )
        img_ = transform_mid(img)

        top = 0
        down = top + lh
        while  down <= h:
            left = 0
            right = left + lw
            while right <= w:

                X_ = img_[:, top : down, left : right].detach()
                X_ = X_.reshape(-1, c, lh, lw)


                Y_1 = model1(X_[:,0].reshape(-1,1,lh,lw))
                Y_2 = model2(X_[:,1].reshape(-1,1,lh,lw))
                Y_3 = model3(X_[:,2].reshape(-1,1,lh,lw))

                Y_ = torch.stack( [Y_1[:,0], Y_2[:,0], Y_3[:,0] ], dim = 1)

                cent_top = (top + opt['cat_padding'])
                cent_left = (left + opt['cat_padding'])
                cent_down = (down - opt['cat_padding']) 
                cent_right = (right - opt['cat_padding'])

                Y[0,:, cent_top : cent_down, cent_left: cent_right] = Y_[0,:, y_top: y_down, y_left: y_right]

                # move patchs to right
                left += stride_w
                right = left + lw
            # move patchs to down
            top += stride_h
            down = top + lh


        Y = verse_normalize(Y, means, stds)
        Y = Y[0]
        Y = tensor_to_pil(Y)
        Y = transform_last(Y)
        Y = tensor_to_pil(Y)
        Y.save("./DRRN.png")
        display_img( Y,  org_img )

        plt.pause(1)
        plt.clf()


"""

For training run this command:
python face_cnn.py train

For testing fun this command:
python face_cnn.py test

"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print( args, len(args))
    if (len(args) == 1) and (args[0] == "train"):
        train()
    elif (len(args) == 1) and (args[0] == "run"):
        run()
    else:
        # test() needs the three trained models, so load them first
        model1 = torch.load("./model/srcnn_l1.m")
        model2 = torch.load("./model/srcnn_l2.m")
        model3 = torch.load("./model/srcnn_l3.m")
        test(model1, model2, model3)


Tags: python machine_learning DL
