# [DL] SRGAN -- super-resolution GAN training / inference script
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random
from torchvision.models import *
# Per-channel ImageNet statistics used by torchvision's Normalize (and
# undone by verse_normalize below).
means = np.array([0.485, 0.456, 0.406])
stds = np.array([0.229, 0.224, 0.225])
channels = 3
# High-resolution (target) patch size; the low-res input is this divided
# by upscale_factor.
pic_height = 128
pic_width = 128
upscale_factor = 2
hr_shape = (pic_height, pic_width)
# Training schedule / optimizer hyper-parameters.
batchs = 1000  # number of batches per train() call
batch_size = 4
learning_rate = 0.0001
beats_l = 0.9    # Adam beta1 ("beats" is a typo for "betas")
beats_r = 0.999  # Adam beta2
"""
# errors
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = (  (img1/1.0) - (img2/1.0)   ) ** 2 
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0**2/mse)
"""
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = (  (img1/255.0) - (img2/255.0)   ) ** 2 
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(1/math.sqrt(mse))
def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB image to YCbCr (BT.601 studio-range matrix).

    Inputs that are not 3-channel images are returned unchanged.  The
    per-pixel Python double loop of the original is replaced by a single
    vectorized matrix multiply; values are identical.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
        #raise ValueError("input image is not a rgb image")
    rgb_image = rgb_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    # pixel' = M @ pixel for every pixel  <=>  image @ M.T, shift broadcast.
    return rgb_image @ transform_matrix.T + shift_matrix
def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) YCbCr image back to uint8 RGB.

    Inputs that are not 3-channel images are returned unchanged.  The
    per-pixel loop is vectorized, and results are clipped to [0, 255]
    before the uint8 cast -- the original cast directly, so out-of-range
    values silently wrapped around.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
        #raise ValueError("input image is not a rgb image")
    ycbcr_image = ycbcr_image.astype(np.float32)
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    # Minv @ pix - Minv @ shift  ==  Minv @ (pix - shift), per pixel.
    rgb_image = (ycbcr_image - shift_matrix) @ transform_matrix_inv.T
    return np.clip(rgb_image, 0, 255).astype(np.uint8)
def verse_normalize(img, means, stds):
    """Reverse a per-channel Normalize on an NCHW 3-channel batch.

    Returns img * std + mean per channel.  Unlike the original, the input
    tensor is NOT modified in place; a fresh tensor is returned (all
    call sites rebind the result, so this is a safe fix for the
    mutate-the-caller footgun).
    """
    means = torch.Tensor(means).view(1, 3, 1, 1)
    stds = torch.Tensor(stds).view(1, 3, 1, 1)
    return img * stds + means
class dataloader:
    """Directory-backed iterator yielding (low-res, high-res) image batches.

    Each step reads `batch_size` image files from `path` and returns a
    pair of stacked tensors: X (bicubic-downsampled by `upscale_factor`)
    and Y (resized to (pic_height, pic_width)), both normalized with the
    module-level ImageNet means/stds.
    """
    def __init__(self, path, batchs,  batch_size, test= False, test_offset = 0.9, order = False):
        # test=True iterates the tail (1 - test_offset) fraction of the
        # (optionally shuffled) file list; otherwise the head fraction.
        self.test = test
        self.test_offset = test_offset
        self.batchs_cnt = 0
        # batchs == 0 disables StopIteration entirely (endless iterator,
        # meant to be consumed with next()) -- see __next__.
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)
        # Train slice: [0, N*test_offset); test slice: [N*test_offset, N).
        self.file_start = 0
        self.file_stop = self.file_number * self.test_offset - 1
        if self.test:
            self.file_start = self.file_number * self.test_offset
            self.file_stop = self.file_number - 1
        self.file_start = int(np.floor(self.file_start))
        self.file_stop = int(np.floor(self.file_stop))
        self.file_idx = self.file_start
        # Low-res input pipeline: bicubic downsample, then normalize.
        self.downsample = transforms.Compose([
            transforms.Resize( (pic_height // upscale_factor, pic_width // upscale_factor), Image.BICUBIC),
            #transforms.CenterCrop( pic_height),
            #transforms.Scale( pic_height // 2, interpolation=Image.BICUBIC),
            #transforms.Resize( (int(pic_height/2),int(pic_width/2)), Image.BILINEAR),
            #transforms.RandomResizedCrop( int(pic_height/2), scale=(1,1) ),
            #transforms.RandomResizedCrop( pic_height, scale=(0.08,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])
        # High-res target pipeline: resize to the full patch, normalize.
        self.transform = transforms.Compose([
            #transforms.CenterCrop( pic_height),
            transforms.Resize( (pic_height,pic_width), Image.BILINEAR),
            #transforms.RandomResizedCrop( pic_height, scale=(1,1) ),
            #transforms.Grayscale(num_output_channels = 3),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])
    def get_len(self):
        # Number of files in this loader's slice of the directory.
        return self.file_stop  - self.file_start
    def __iter__(self):
        return self
    def __next__(self):
        # Stop after `batchs` batches; with batchs == 0 this never stops.
        if (self.batchs_cnt >= self.batchs) & (self.batchs > 0):
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1
        X = []
        Y = []
        for i in range( self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)
        X = torch.stack(X, 0)
        Y = torch.stack(Y, 0)
        return X, Y
    def _next(self):
        # Read one file (wrapping around at the end of this loader's
        # slice) and return its (low-res, high-res) tensor pair.
        if self.file_idx >= self.file_stop:
            self.file_idx = self.file_start
        file_path = self.path + '/' + self.file_list[self.file_idx]
        self.file_idx += 1
        Y = Image.open(file_path)
        #print( "Y:", file_path, " (", Y.size )
        if len(Y.getbands()) == 1:
            # Grayscale source: replicate to 3 channels.
            Y = Y.convert("RGB")
        Y_ = self.transform(Y)
        X_ = self.downsample(Y)
        del Y
        return X_,Y_
class ResidualBlock(nn.Module):
    """conv-BN-PReLU-conv-BN block with an identity skip connection."""
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()   # in_features: 64
        # The original passed 0.8 positionally to BatchNorm2d, which set
        # eps (the second positional parameter) to 0.8 -- a silent
        # numerics bug.  0.8 is intended as the momentum; eps stays at
        # its default 1e-5.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        """Return x + F(x); spatial and channel shapes are preserved."""
        return x + self.conv_block(x)
class GeneratorResNet(nn.Module):
    """SRResNet-style generator: conv head, residual trunk with a global
    skip connection, one pixel-shuffle x2 upsampling stage, tanh output."""
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()
        # Head: large receptive field, no normalization.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)
        # Post-trunk conv.  The original passed 0.8 positionally to
        # BatchNorm2d, setting eps instead of the intended momentum.
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, momentum=0.8))
        upsampling = []
        # One conv + PixelShuffle(2) stage doubles the spatial size
        # (matches upscale_factor == 2); use two stages for x4.
        for _ in range(1):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        """Map a low-res batch (N, C, H, W) to (N, C, 2H, 2W)."""
        out1 = self.conv1(x)
        out2 = self.res_blocks(out1)
        out3 = self.conv2(out2)
        # Global skip connection around the residual trunk.
        out4 = torch.add(out1, out3)
        out5 = self.upsampling(out4)
        return self.conv3(out5)
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: four stride-2 conv stages followed by
    a single-channel conv head, yielding a (1, H/16, W/16) score map."""

    def __init__(self, input_shape): # input_shape: (channels, height, width)
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Each of the four stages halves the spatial resolution -> /16.
        self.output_shape = (1, int(in_height / 2 ** 4), int(in_width / 2 ** 4))

        def discriminator_block(in_filters, out_filters, first_block=False):
            # conv(stride 1) [+ BN] + LeakyReLU, then conv(stride 2) + BN
            # + LeakyReLU; BN is skipped on the raw-image block only.
            block = [nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1)]
            if not first_block:
                block.append(nn.BatchNorm2d(out_filters))
            block += [
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_filters),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            return block

        stages = []
        prev_width = in_channels
        for idx, width in enumerate([64, 128, 256, 512]):
            stages.extend(discriminator_block(prev_width, width, first_block=(idx == 0)))
            prev_width = width
        stages.append(nn.Conv2d(prev_width, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Return the patch-wise realness score map for `img`."""
        return self.model(img)
class FeatureExtractor(nn.Module):
    """Fixed VGG19 feature extractor used for the perceptual content loss."""
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # NOTE(review): pretrained=True downloads ImageNet weights on first
        # use (deprecated in newer torchvision in favor of `weights=`).
        vgg19_model = vgg19(pretrained=True) 
        # keep only the first 18 modules of vgg19's `features` stack
        # (the original comment said "vgg16", which was wrong)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
    def forward(self, img):
        return self.feature_extractor(img)
# --- Module-level setup: networks, losses, optimizers, data loaders. ---
# NOTE(review): everything runs on CPU; there are no .to(device) calls.
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(channels, *hr_shape))
feature_extractor = FeatureExtractor()
print(generator)
print(discriminator)
# VGG features serve as a fixed perceptual metric and are never trained.
feature_extractor.eval()
# Adversarial loss (LSGAN-style MSE) and perceptual content loss (L1).
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beats_l, beats_r))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beats_l, beats_r))
tensor_to_pil = ToPILImage()
'''
dataloader = dataloader('./dataset/Set14/image_SRF_2')
plt.figure("ddd")
plt.ion()
plt.cla()
for X,Y in dataloader:
    plt.imshow(X)
    plt.pause(0.1)
plt.ioff()
iii = tensor_to_pil(Y_[0])
plt.imshow(iii)
plt.show()
'''
# NOTE(review): the training path below is machine-specific; adjust it
# before running.
dataloader_train = dataloader('/home/nicholas/Documents/dataset/imageNet', batchs, batch_size)
#dataloader_train = dataloader('./dataset/Set14/image_SRF_3', 10, 4)
#dataloader_test = dataloader('/home/nicholas/Documents/dataset/imageNet', 1, 2)
#dataloader_test = dataloader('./dataset/Set14/image_SRF_3', 1, 4)
# batchs=0 makes the test loader endless; test() consumes it with next().
dataloader_test = dataloader('./images', 0, 3)
new_training = 0
Tensor = torch.Tensor
def train():
    """Run adversarial SRGAN training over `dataloader_train`.

    Resumes from ./model/*.m checkpoints when present, alternates a
    generator step (VGG content loss + 1e-3 * adversarial loss) with a
    discriminator step (real/fake MSE on patch labels), and every 5
    batches runs a quick visual test and checkpoints both networks.
    """
    try:
        generator.load_state_dict(torch.load("./model/generator.m"))
        discriminator.load_state_dict(torch.load("./model/discriminator.m"))
    except Exception:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        print(" ************** Model config files not accessable !!! ************")
    plt.ion()
    i = 0
    os.makedirs("./model", exist_ok=True)  # ensure the checkpoint dir exists
    for X, Y in dataloader_train:
        # Ground-truth label maps for the patch discriminator.
        valid = Variable(Tensor(np.ones((X.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((X.size(0), *discriminator.output_shape))), requires_grad=False)
        # ---- Generator step ----
        optimizer_G.zero_grad()
        gen_Y = generator(X)
        loss_GAN = criterion_GAN(discriminator(gen_Y), valid)
        features_gen_Y = feature_extractor(gen_Y)
        features_Y = feature_extractor(Y)
        loss_content = criterion_content(features_gen_Y, features_Y.detach())
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()
        # ---- Discriminator step (generator output detached) ----
        optimizer_D.zero_grad()
        loss_valid = criterion_GAN(discriminator(Y), valid)
        loss_fake = criterion_GAN(discriminator(gen_Y.detach()), fake)
        loss_D = (loss_valid + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()
        print("Epoch %d/%d: D loss: %f, G loss: %f"
              % (i, batchs, loss_D.item(), loss_G.item()) + '\n')
        #display_img( gen_Y[0], Y[0], X[0])
        i += 1
        if i % 5 == 0:
            test()
            torch.save(generator.state_dict(), "./model/generator.m")
            torch.save(discriminator.state_dict(), "./model/discriminator.m")
    plt.ioff()
def display_img(img1, img2, img3 = None):
    """Show two (or three, when img3 is given) images side by side."""
    images = [img1, img2] + ([img3] if img3 is not None else [])
    total = len(images)
    for slot, image in enumerate(images, start=1):
        plt.subplot(1, total, slot)
        plt.imshow(image)
    plt.show()
def test():
    """Super-resolve one test batch and display output/target/input.

    Pulls a single batch from `dataloader_test`, prints the PSNR of the
    generator output against the target, denormalizes all three tensors
    and shows them with matplotlib.
    """
    for _ in range(1):
        X, Y = next(dataloader_test)
        print(X.shape)
        # Inference only: no_grad avoids building an autograd graph (the
        # original output required grad, which made psnr's numpy
        # conversion raise).
        with torch.no_grad():
            Y_ = generator(X)
        # NOTE(review): psnr is computed on normalized tensors here, so the
        # values are not on the 0-255 scale -- confirm this is intended.
        print(" Psnr = ", psnr(Y_, Y))
        # Undo the ImageNet normalization before visualization.
        Y_ = verse_normalize(Y_, means, stds)
        Y = verse_normalize(Y, means, stds)
        X = verse_normalize(X, means, stds)
        display_img(tensor_to_pil(Y_[0]), tensor_to_pil(Y[0]), tensor_to_pil(X[0]))
        plt.pause(0.1)
def run():
    """Load the trained generator and super-resolve every image in a folder.

    Each source image is treated as low-res input: it is tiled into
    (lh x lw) blocks, every block is upscaled by the generator, and the
    tiles are stitched into one tensor that is finally center-cropped to
    exactly (h * upscale_factor, w * upscale_factor) and saved.
    """
    generator.load_state_dict(torch.load("./model/generator.m"))
    discriminator.load_state_dict(torch.load("./model/discriminator.m"))
    generator.eval()
    discriminator.eval()
    path = "./dataset/Set5/image_SRF_2"
    file_list = os.listdir(path)
    #plt.ion()
 
    for img_name in file_list:
        img = Image.open(path+'/'+img_name)
        if len(img.getbands()) == 1:
            # Grayscale source: replicate to 3 channels.
            img = img.convert("RGB")
        c = 3
        # turn ( h X w ) to  ( uh X uw )
        w, h = img.size
        uh = h * upscale_factor
        uw = w * upscale_factor
        # block size (the generator's low-res input tile)
        lh = pic_height // upscale_factor
        lw = pic_width // upscale_factor
        # blocks number (ceil so partial edge tiles are covered)
        rows = math.ceil( h / lh )
        cols = math.ceil( w / lw )
        # tmp img size (rounded up to a whole number of tiles)
        h_mid = rows * lh
        w_mid = cols * lw
        uh_mid = h_mid * upscale_factor
        uw_mid = w_mid * upscale_factor
        # NOTE(review): CenterCrop to a size larger than the image pads it;
        # the padding is cropped away again by transform_last below.
        transform_mid = transforms.Compose([
            transforms.CenterCrop( (h_mid, w_mid) ),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])
        transform_last = transforms.Compose([
            transforms.CenterCrop( (uh, uw) ),
            transforms.ToTensor()
            ])
        Y = torch.zeros( (1, 3, uh_mid, uw_mid) )
        img_ = transform_mid(img)
        row = 0
        col = 0
        # Upscale tile by tile and write each result into its slot in Y.
        # NOTE(review): the generator runs without torch.no_grad() here;
        # consider wrapping this loop to avoid building autograd graphs.
        for row in range(rows):
            for col in range(cols):
                X_ = img_[:, row * lh : (row + 1) * lh, col * lw : (col + 1) * lw].detach()
                X_ = X_.reshape(-1, c, lh, lw)
                Y_ = generator(X_)
                Y[0, :, row * pic_height: (row+1) * pic_height, col * pic_width: (col+1) * pic_width] = Y_
        # Undo normalization, crop away the tiling padding, save and show.
        Y = verse_normalize(Y, means, stds)
        Y = Y[0]
        Y = tensor_to_pil(Y)
        Y = transform_last(Y)
        Y = tensor_to_pil(Y)
        Y.save("./SRGAN.png")
        display_img( Y, img)
        plt.pause(1)
        plt.clf()
    #plt.ioff()
"""
For training run this command:
python face_cnn.py train
For testing fun this command:
python face_cnn.py test
"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print(args, len(args))
    # `and` short-circuits: with no arguments, args[0] is never evaluated.
    # (The original used `&`, which evaluates both operands and raised
    # IndexError when the script was run with no arguments.)
    if len(args) == 1 and args[0] == "train":
        train()
    elif len(args) == 1 and args[0] == "run":
        run()
    else:
        test()
# ---------------------------------------------------------------------------
# Non-code residue scraped from the hosting web page (tags, calendar,
# blog micro-posts and reader comments), preserved below as comments so
# the file remains valid Python.
# ---------------------------------------------------------------------------
# 标签: python machine_learning DL
# 日历
# 最新微语
# - 有的时候,会站在分叉路口,不知道向左还是右 (2023-12-26 15:34)
# - 繁花乱开,鸟雀逐风。心自宁静,纷扰不闻。 (2023-03-14 09:56)
# - 对于不可控的事,我们保持乐观,对于可控的事情,我们保持谨慎。 (2023-02-09 11:03)
# - 小时候,
#   暑假意味着无忧无虑地玩很长一段时间,
#   节假意味着好吃好喝还有很多长期不见的小朋友来玩...
#   长大后,
#   这是女儿第一个暑假,
#   一个半月... (2022-07-11 08:54)
# - Watching the autumn leaves falling as you grow older together (2018-10-25 09:45)
#
# 分类
# 最新评论
# - Goonog: i get it now :)
# - 萧: @Fluzak: The web host...
# - Fluzak: Nice blog here! Also...
# - Albertarive: In my opinion you co...
# - ChesterHep: What does it plan?
# - ChesterHep: No, opposite.
# - mojoheadz: Everything is OK!...
# - Josephmaigh: I just want to say t...
# - ChesterHep: What good topic
# - AnthonyBub: Certainly, never it ...
# 发表评论: