[DL] ESRT: Transformer for Single Image Super-Resolution
"""Author: Nicholas Xiao
Blog: log.anycle.com
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random
import copy
# Per-channel mean and std handed to transforms.Normalize and undone again by verse_normalize.
means = [0.485, 0.456, 0.406]
stds = [ 0.229, 0.224, 0.225]
# Shared settings read by the data loader, the model defaults and the tiling code in run().
opt = {
"pic_c":32, # channel count passed to ESRT as `c`
"pic_h":96, # height of the full-size image; the low-resolution input is pic_h // pic_up
"pic_w":96, # width of the full-size image; the low-resolution input is pic_w // pic_up
"pic_up":2, # upscale factor
"heads":8, # default number of attention heads used by LTB
# NOTE(review): the keys from "vec_dim" through "eof" are not read anywhere in the visible code.
"vec_dim":512,
"N":6,
"x_vocab_len":0,
"y_vocab_len":0,
"sentence_len":80,
"batchs":1000,
"batch_size":10,
"pad":0,
"sof":1,
"eof":2,
"cat_padding": 5, # border, in low-resolution pixels, discarded around every tile in run()
}
"""
# errors
def psnr(img1, img2):
img1 = Variable(img1)
img2 = Variable(img2)
mse = ( (img1/1.0) - (img2/1.0) ) ** 2
mse = np.mean( np.array(mse) )
if mse < 1.0e-10:
return 100
return 10 * math.log10(255.0**2/mse)
"""
def psnr(img1, img2, data_range=255.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Both inputs are divided by ``data_range`` so the peak value becomes 1,
    and the result is ``10 * log10(1 / MSE)``.

    Args:
        img1: Tensor (or array-like) holding the first image.
        img2: Tensor (or array-like) with the same shape as ``img1``.
        data_range: Peak value of the input scale. The default of 255.0 is
            the scaling this function has always applied; pass 1.0 for
            images that are already in ``[0, 1]``.

    Returns:
        ``100`` when the mean squared error is below ``1e-10``, otherwise the
        PSNR in decibels as a float.
    """
    first = torch.as_tensor(img1).detach().to(torch.float64)
    second = torch.as_tensor(img2).detach().to(torch.float64)
    mse = torch.mean(((first - second) / data_range) ** 2).item()
    if mse < 1.0e-10:
        return 100
    # PSNR = 10 * log10(MAX^2 / MSE) and MAX is 1 after the normalisation.
    # The earlier 1 / sqrt(MSE) form reported only half of that value.
    return 10 * math.log10(1.0 / mse)
def rgb2ycbcr(rgb_image):
    """Convert an (rows, cols, 3) RGB array to YCbCr.

    Args:
        rgb_image: NumPy array. Anything that is not three-dimensional with
            a last axis of length 3 is returned unchanged.

    Returns:
        A float64 array of the same shape holding
        ``transform_matrix @ rgb + [16, 128, 128]`` for every pixel, or the
        input itself when it is not a 3-channel image.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    shift_matrix = np.array([16, 128, 128])
    pixels = rgb_image.astype(np.float32)
    # One matrix product over the channel axis replaces the per-pixel Python
    # loop; right-multiplying by the transpose applies the matrix to each pixel.
    return pixels @ transform_matrix.T + shift_matrix
def ycbcr2rgb(ycbcr_image):
    """Convert an (rows, cols, 3) YCbCr array back to 8-bit RGB.

    Args:
        ycbcr_image: NumPy array. Anything that is not three-dimensional
            with a last axis of length 3 is returned unchanged.

    Returns:
        A uint8 array of the same shape, or the input itself when it is not
        a 3-channel image.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
    transform_matrix = np.array([[0.257, 0.564, 0.098],
                                 [-0.148, -0.291, 0.439],
                                 [0.439, -0.368, -0.071]])
    transform_matrix_inv = np.linalg.inv(transform_matrix)
    shift_matrix = np.array([16, 128, 128])
    centered = ycbcr_image.astype(np.float64) - shift_matrix
    rgb_image = centered @ transform_matrix_inv.T
    # A bare astype(uint8) truncates (254.9999 -> 254) and wraps values
    # outside 0..255, so round to nearest and clip before narrowing.
    return np.clip(np.rint(rgb_image), 0, 255).astype(np.uint8)
class dataloader:
    """Cycle through the image files of one directory and yield (X, Y) batches.

    In training mode (``test=False``) every file is one sample: ``Y`` is the
    image resized to ``pic_h x pic_w`` and ``X`` is the same image resized to
    ``(pic_h // pic_up) x (pic_w // pic_up)``.

    In test mode (``test=True``) files are consumed two at a time: the first
    is center-cropped to the low-resolution size and used as ``X``, the second
    is center-cropped to the full size and used as ``Y``.

    Args:
        path: Directory whose entries are all treated as image files.
        batchs: Number of batches per iteration pass; a value <= 0 means the
            iterator never stops.
        batch_size: Number of samples stacked into each batch.
        test: Select the paired test mode described above.
        order: Sort the file list instead of shuffling it.
    """

    def __init__(self, path, batchs, batch_size, test=False, order=False):
        self.test = test
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        # Test mode pairs neighbouring files, so a shuffled list would pair
        # unrelated images; it is therefore always sorted.
        # NOTE(review): this assumes sorted names put each low-resolution
        # file directly before its full-size counterpart - confirm.
        if order or test:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)
        self.file_start = 0
        self.file_stop = self.file_number - 1
        self.file_idx = self.file_start

        low_size = (opt['pic_h'] // opt['pic_up'], opt['pic_w'] // opt['pic_up'])
        full_size = (opt['pic_h'], opt['pic_w'])
        # Training input: full image resized to the low-resolution size.
        self.downsample = transforms.Compose([
            transforms.Resize(low_size, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
        ])
        # Training target: full image resized to the full size.
        self.transform = transforms.Compose([
            transforms.Resize(full_size, Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
        ])
        # Test input: center crop at the low-resolution size.
        self.center_down = transforms.Compose([
            transforms.CenterCrop(low_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
        ])
        # Test target: center crop at the full size.
        self.center = transforms.Compose([
            transforms.CenterCrop(full_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
        ])

    def get_len(self):
        """Return ``file_stop - file_start`` (the file count minus one)."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(X, Y)`` batch, each stacked along dimension 0."""
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            # Reset so the loader can be iterated again.
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1
        samples = [self._next() for _ in range(self.batch_size)]
        X = torch.stack([x for x, _ in samples], 0)
        Y = torch.stack([y for _, y in samples], 0)
        return X, Y

    @staticmethod
    def _load_rgb(file_path):
        """Open ``file_path``, return an RGB copy and close the file."""
        with Image.open(file_path) as img:
            # convert() returns a fully loaded copy, so the file can be closed.
            # Converting every mode keeps Normalize's three means valid for
            # grayscale, palette, RGBA and CMYK files alike.
            return img.convert("RGB")

    def _next(self):
        """Return one ``(X, Y)`` sample and advance the file cursor."""
        step = 2 if self.test else 1
        # Wrap around once fewer than `step` files remain. The previous check
        # compared against file_stop and so never served the last file in
        # training mode.
        if self.file_idx + step > self.file_number:
            self.file_idx = self.file_start
        if self.file_idx + step > self.file_number:
            raise IndexError(
                "not enough image files in {!r}: need {}, found {}".format(
                    self.path, step, self.file_number))
        file_path = os.path.join(self.path, self.file_list[self.file_idx])
        if self.test:
            file_path_hr = os.path.join(self.path, self.file_list[self.file_idx + 1])
            self.file_idx += 2
            X = self.center_down(self._load_rgb(file_path))
            Y = self.center(self._load_rgb(file_path_hr))
        else:
            self.file_idx += 1
            source = self._load_rgb(file_path)
            X = self.downsample(source)
            Y = self.transform(source)
        return X, Y
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
# RU
class RU(nn.Module):
    """Residual unit: 1x1 squeeze to ``c // 2`` channels, 1x1 expand back to
    ``c``, scale by the learnable ``lamda`` and add the input."""

    def __init__(self, c, lamda):
        super().__init__()
        half = int(c // 2)
        # Learnable scale on the residual branch, initialised to `lamda`.
        self.lamda = nn.Parameter(torch.tensor([lamda], dtype=torch.float32))
        self.reduction = nn.Conv2d(c, half, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv2d(half, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        branch = self.expansion(self.reduction(x))
        return branch * self.lamda + x
# Adaptive Residual Feature Block
class ARFB(nn.Module):
    """Adaptive residual feature block.

    The same ``RU`` instance is applied twice; both outputs are concatenated
    on the channel axis, passed through a 1x1 and a 3x3 convolution, scaled
    by ``lamda_res`` and added to the input scaled by ``lamda_x``.
    """

    def __init__(self, c, lamda_res, lamda_x, lamda_ru):
        super().__init__()
        doubled = 2 * c
        self.ru = RU(c, lamda_ru)
        self.conv1 = nn.Conv2d(doubled, doubled, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(doubled, c, kernel_size=3, stride=1, padding=1)
        self.lamda_res = nn.Parameter(torch.FloatTensor([lamda_res]))
        self.lamda_x = nn.Parameter(torch.FloatTensor([lamda_x]))

    def forward(self, x):
        first = self.ru(x)
        # The second pass reuses the same weights as the first.
        stacked = torch.cat((first, self.ru(first)), dim=1)
        residual = self.conv3(self.conv1(stacked)) * self.lamda_res
        return x * self.lamda_x + residual
# High-frequency Filtering Module
class HFM(nn.Module):
    """High-frequency filter: subtract from the input its ``k x k`` average
    pooled copy, upsampled back with nearest-neighbour interpolation.

    ``c`` is accepted for signature symmetry with the other blocks but is
    not used.
    """

    def __init__(self, c, k):
        super().__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=k, stride=k, padding=0)
        self.upsample = nn.Upsample(scale_factor=k, mode='nearest')

    def forward(self, TL):
        low_freq = self.upsample(self.avgPool(TL))
        return TL - low_freq
class CA(nn.Module):
    """Channel attention: global average pool, squeeze to ``c // reduction``
    channels, expand back to ``c``, sigmoid, then rescale the input per
    channel."""

    def __init__(self, c, reduction=16):
        super().__init__()
        squeezed = c // reduction
        self.avgPool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(c, squeezed, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, c, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.convDu = nn.Sequential(*gate_layers)

    def forward(self, x):
        gate = self.convDu(self.avgPool(x))
        return x * gate
class HPB(nn.Module):
    """High-preserving block built from one shared ARFB, an HFM and a CA.

    Args:
        c: Number of feature channels.
        k: Pooling factor handed to ``HFM``.
        lamda_res, lamda_x, lamda_ru: Initial values for the ARFB scales.
    """

    def __init__(self, c, k, lamda_res=0.5, lamda_x=0.5, lamda_ru=0.5):
        super().__init__()
        self.arfb = ARFB(c, lamda_res, lamda_x, lamda_ru)
        self.hfm = HFM(c, k)
        self.conv1 = nn.Conv2d(2 * c, c, kernel_size=1, stride=1, padding=0)
        self.ca = CA(c, 2)

    def forward(self, x):
        # Shared stem.
        y = self.arfb(x)
        y = self.hfm(y)
        # Side branch: halve, five ARFB passes, double again.
        y2 = y / 2
        for _ in range(5):
            y2 = self.arfb(y2)
        # This used to read `y * 2`, which discarded the five passes above
        # and reduced the branch to a scaled copy of `y`.
        # NOTE(review): checkpoints trained with the old line will produce
        # different outputs and may need fine-tuning.
        y2 = y2 * 2
        # Main branch, then fuse both branches.
        y = self.arfb(y)
        y = torch.cat((y, y2), dim=1)
        y = self.conv1(y)
        y = self.ca(y)
        y = self.arfb(y)
        return x + y
class LCB(nn.Module):
    """Chain of ``hpb_num`` independent copies of ``HPB(c, k)`` applied in order."""

    def __init__(self, c, k=2, hpb_num=3):
        super().__init__()
        self.hpb_num = hpb_num
        self.layers = get_clones(HPB(c, k), hpb_num)

    def forward(self, x):
        out = x
        for idx in range(self.hpb_num):
            block = self.layers[idx]
            out = block(out)
        return out
# attention
def attention(q, k, v, dk, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: Query, key and value tensors; the last two dimensions of
            ``q`` and ``k`` are multiplied as ``q @ k^T``.
        dk: Value whose square root divides the scores.
        mask: Optional tensor; after ``unsqueeze(1)``, positions equal to 0
            are filled with ``-1e9`` before the softmax.
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention weights multiplied by ``v``.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)
# B*C*N
class EMHA(nn.Module):
    """Multi-head attention over a (B, C, N) tensor, computed on ``s`` chunks.

    Channels are reduced to ``c // 2``, projected to q/k/v, split into
    ``heads`` heads, and the token axis is cut into ``s`` chunks that are
    attended separately. The chunks are concatenated again and the result
    is expanded back to ``c`` channels.
    """

    def __init__(self, c, heads=8, s=4, dropout=0.1):
        super().__init__()
        self.s = s
        self.c1 = c // 2
        self.dk = self.c1 // heads
        self.h = heads
        self.dropout = nn.Dropout(dropout)
        self.reduction = nn.Conv1d(c, self.c1, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv1d(self.c1, c, kernel_size=1, stride=1, padding=0)
        self.lq = nn.Linear(self.c1, self.c1)
        self.lk = nn.Linear(self.c1, self.c1)
        self.lv = nn.Linear(self.c1, self.c1)

    def forward(self, x):
        batch = x.size(0)
        # (B, C, N) -> (B, C1, N) -> (B, N, C1)
        tokens = self.reduction(x).transpose(1, 2)

        def split(projection):
            # (B, N, C1) -> (B, N, M, dk) with M heads
            heads_view = projection(tokens).view(batch, -1, self.h, self.dk)
            b, _, m, d = heads_view.shape
            # -> (B, S, N_s, M, dk) -> (B, S, M, N_s, dk) -> (S, B, M, N_s, dk)
            chunked = heads_view.view(b, self.s, -1, m, d)
            return chunked.transpose(2, 3).transpose(0, 1)

        q = split(self.lq)
        k = split(self.lk)
        v = split(self.lv)

        # Attend inside every chunk; each result is (B, M, N_s, dk).
        pieces = [
            attention(q[idx], k[idx], v[idx], self.dk, None, self.dropout)
            for idx in range(self.s)
        ]
        # Join the chunks on the token axis: (B, M, N, dk)
        joined = torch.cat(pieces, dim=2)
        # (B, M, N, dk) -> (B, N, C1) -> (B, C1, N)
        merged = joined.transpose(1, 2).contiguous().view(batch, -1, self.c1)
        merged = merged.transpose(1, 2)
        # (B, C1, N) -> (B, C, N)
        return self.expansion(merged)
class Norm(nn.Module):
    """Normalise over the last dimension with a learnable scale and bias.

    ``alpha`` and ``bias`` both have ``d_model`` elements and are broadcast
    against the last axis of the input.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps keeps the division finite when the standard deviation is zero.
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias
# Efficient Transformer
# (B C H W) to (B C N), N = d_model
class ET(nn.Module):
    """Efficient transformer layer on a (B, C, N) tensor.

    Two pre-norm residual stages: ``x + EMHA(norm1(x))`` followed by
    ``y + Linear(norm2(y))``. Both norms and the linear layer act on the
    last axis, whose length is ``d_model``.
    """

    def __init__(self, c, d_model, heads):
        super().__init__()
        self.c = c
        self.d_model = d_model
        self.norm1 = Norm(self.d_model)
        self.norm2 = Norm(self.d_model)
        self.emha = EMHA(c, heads)
        self.mlp = nn.Linear(self.d_model, self.d_model)

    def forward(self, x):
        # Attention stage with its residual connection.
        attended = x + self.emha(self.norm1(x))
        # Linear stage with its residual connection.
        return attended + self.mlp(self.norm2(attended))
class LTB(nn.Module):
    """Transformer backbone: unfold the feature map into ``k x k`` patches,
    run ``et_num`` ET layers over the (B, c*k*k, h*w) result and fold it back
    to a (B, c, h, w) map.

    ``h`` and ``w`` default to the low-resolution size derived from ``opt``
    and ``heads`` to ``opt['heads']``.
    """

    def __init__(self, c, h=int(opt['pic_h']/opt['pic_up']), w=int(opt['pic_w']/opt['pic_up']), k=3, heads=opt['heads'], et_num=1):
        super().__init__()
        patch_cfg = dict(kernel_size=k, dilation=1, stride=1, padding=1)
        self.et_num = et_num
        self.unfold = nn.Unfold(**patch_cfg)
        self.layers = get_clones(ET(c * k * k, h * w, heads), et_num)
        self.fold = nn.Fold((h, w), **patch_cfg)

    def forward(self, x):
        tokens = self.unfold(x)
        for idx in range(self.et_num):
            tokens = self.layers[idx](tokens)
        return self.fold(tokens)
class ESRT(nn.Module):
    """ESRT super-resolution network.

    Main path: conv -> LCB -> LTB -> conv -> PixelShuffle -> conv.
    Skip path: the first conv's output -> PixelShuffle -> conv.
    The two 3-channel results are summed.

    Args:
        c: Number of feature channels; must be divisible by ``upscale ** 2``.
        h: Height of the full-size output; the network input is expected to
            be ``h // upscale`` pixels tall.
        w: Width of the full-size output; the network input is expected to
            be ``w // upscale`` pixels wide.
        upscale: Integer upscale factor handed to ``nn.PixelShuffle``.

    Raises:
        ValueError: If ``c`` is not divisible by ``upscale ** 2``.
    """

    def __init__(self, c, h, w, upscale):
        super().__init__()
        shuffle_div = upscale ** 2
        if c % shuffle_div != 0:
            raise ValueError(
                "c ({}) must be divisible by upscale ** 2 ({})".format(c, shuffle_div))
        # Bookkeeping read and updated by train().
        self.train_loops = 0
        self.lr = 2e-4
        # PixelShuffle(r) turns C channels into C / r**2. The earlier
        # 2 ** upscale only matched that for upscale == 2.
        self.c1 = c // shuffle_div
        # LTB works on the low-resolution feature map, so it gets the
        # low-resolution size instead of silently using the opt defaults.
        self.ltb = LTB(c, h // upscale, w // upscale)
        self.lcb = LCB(c)
        self.conv1 = nn.Conv2d(3, c, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.up = nn.PixelShuffle(upscale)
        self.upscale = upscale

    def forward(self, x):
        y1 = self.conv1(x)
        # Main path.
        y2 = self.lcb(y1)
        y2 = self.ltb(y2)
        y2 = self.conv2(y2)
        y2 = self.up(y2)
        y2 = self.conv3(y2)
        # Skip path.
        y1 = self.up(y1)
        y1 = self.conv4(y1)
        return y2 + y1

    def init_weights(self):
        """Placeholder: only prints a message, no weights are modified."""
        print(" init weights \n")
def tmp_test():
    """Smoke test: push one deterministic batch through a freshly built ESRT.

    The model is built from ``opt`` and fed a (1, 3, pic_h // pic_up,
    pic_w // pic_up) tensor. The earlier version passed a 4-channel 4x4
    tensor to ``ESRT(4, 4, 4, 1)``, which cannot run: the first convolution
    expects 3 input channels and the attention reshapes do not divide evenly
    for that width.
    """
    low_h = opt["pic_h"] // opt["pic_up"]
    low_w = opt["pic_w"] // opt["pic_up"]
    count = 3 * low_h * low_w
    # Deterministic ramp in [0, 1) so repeated runs see the same input.
    a = torch.arange(count, dtype=torch.float).reshape(1, 3, low_h, low_w) / count
    rus = ESRT(opt["pic_c"], opt["pic_h"], opt["pic_w"], opt["pic_up"])
    rus.eval()  # switch dropout off so the printed output is repeatable for fixed weights
    with torch.no_grad():
        b = rus(a)
    print(a, "\n", b)
# Module-level state shared by train(), test() and run().
tensor_to_pil = ToPILImage()
model = ESRT(opt["pic_c"], opt["pic_h"], opt["pic_w"], opt["pic_up"])
dataloader_train = dataloader('/home/nicholas/Documents/dataset/DIV2K_train_LR_unknown/X2', 10000000, 5)
dataloader_test = dataloader('./dataset/images', 1, 5, test=True)
# new_training is 1 when no checkpoint exists yet, 0 when one can be resumed.
_checkpoint_found = os.path.exists("./model/esrt.m")
print(" model exists ... \n" if _checkpoint_found else " model not exists, will create a new one. \n")
new_training = 0 if _checkpoint_found else 1
def display_img(img1, img2):
    """Show ``img1`` and ``img2`` side by side in a 1x2 subplot grid."""
    for slot, picture in enumerate((img1, img2), start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)
    plt.show()
def verse_normalize(img, means, stds):
    """Undo ``transforms.Normalize`` in place: ``img * std + mean`` per channel.

    Args:
        img: Tensor shaped (C, H, W) or (B, C, H, W); it is modified in place.
        means: Per-channel means used for the normalisation.
        stds: Per-channel standard deviations used for the normalisation.

    Returns:
        The same tensor object that was passed in.
    """
    # Shape (C, 1, 1) broadcasts over the channel axis of 3-D and 4-D input.
    # The earlier `img[:, i]` indexing addressed rows, not channels, for a
    # 3-D tensor.
    mean_t = torch.as_tensor(means, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(stds, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    img.mul_(std_t).add_(mean_t)
    return img
def train():
    """Train the global ``model`` on ``dataloader_train``.

    The learning rate is halved every 2000 batches. Every 10 batches the
    whole model is saved to ``./model/esrt.m`` and ``test()`` is run.
    """
    global model
    if new_training == 1:
        print(" new training ...")
        model.init_weights()
    else:
        print(" continue training ...")
        # NOTE(review): this unpickles a complete module object; only load
        # checkpoints from a trusted source.
        model = torch.load("./model/esrt.m")
    loss = nn.MSELoss()
    plt.ion()
    lr_index = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=model.lr, betas=(0.9, 0.999))
    for X, Y in dataloader_train:
        lr_index += 1
        # `is 0` compared identity, not value; use a numeric comparison.
        if lr_index % 2000 == 0:
            model.lr /= 2
            print(" model lr = ", model.lr)
            # Change the rate on the existing optimizer so Adam keeps its
            # running moment estimates instead of restarting from zero.
            for group in optimizer.param_groups:
                group["lr"] = model.lr
        # test() below switches to eval mode, so re-enter train mode each step.
        model.train()
        optimizer.zero_grad()
        Y_ = model(X)
        cost = loss(Y_, Y)
        cost.backward()
        optimizer.step()
        print("batch:{}, loss is {:.4f} ".format(model.train_loops, cost.item()))
        model.train_loops += 1
        if model.train_loops % 10 == 0:
            # A first run may start without the checkpoint directory.
            os.makedirs("./model", exist_ok=True)
            torch.save(model, "./model/esrt.m")
            test()
    plt.ioff()
def test():
    """Run the global ``model`` on ``dataloader_test``, plot the first sample
    of every batch and print its PSNR against the target."""
    global model
    model.eval()
    for X, Y in dataloader_test:
        _, _, dst_h, dst_w = Y.shape
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            Y_ = model(X)
        # De-normalise on a 4-D slice, where dimension 1 is the channel
        # axis, then clamp so values outside [0, 1] do not wrap around in
        # the 8-bit conversion.
        OX = tensor_to_pil(verse_normalize(X[:1].clone(), means, stds)[0].clamp(0, 1))
        plt.subplot(1, 2, 2)
        plt.imshow(OX)
        OY_ = tensor_to_pil(verse_normalize(Y_[:1], means, stds)[0].clamp(0, 1))
        plt.subplot(1, 2, 1)
        plt.imshow(OY_)
        OY = tensor_to_pil(verse_normalize(Y[:1], means, stds)[0].clamp(0, 1))
        transform_last = transforms.Compose([
            transforms.CenterCrop((dst_h, dst_w)),
            transforms.ToTensor(),
        ])
        # ToTensor yields values in [0, 1]; psnr() divides by 255, so put
        # the data back on the 0..255 scale first.
        print(" Psnr = ", psnr(transform_last(OY_) * 255.0, transform_last(OY) * 255.0))
        plt.show()
        plt.pause(0.1)
def run(path=None):
    """Super-resolve every image in ``path`` tile by tile.

    Each image is padded, cut into overlapping low-resolution tiles of the
    size the model expects and upscaled tile by tile. Only the center of
    every output tile is kept, which drops ``cat_padding`` low-resolution
    pixels on each side. The result is written to ``./ESRT.png`` (one file,
    overwritten for every image) and shown next to the input.

    Args:
        path: Directory with the input images; ``./images`` when ``None``.
    """
    # NOTE(review): this unpickles a complete module object; only load
    # checkpoints from a trusted source.
    model = torch.load("./model/esrt.m")
    model.eval()
    if path is None:
        path = "./images"
    file_list = os.listdir(path)

    pad = opt['cat_padding']
    up = opt['pic_up']
    # Tile size fed to the model.
    lh = opt['pic_h'] // up
    lw = opt['pic_w'] // up
    # Step between tiles; neighbouring tiles overlap by 2 * pad pixels.
    stride_h = lh - pad * 2
    stride_w = lw - pad * 2
    # Central region of a model output that is kept.
    y_top = pad * up
    y_left = pad * up
    y_down = (lh - pad) * up
    y_right = (lw - pad) * up

    for img_name in file_list:
        # Load as RGB and close the file. Converting every mode keeps the
        # 3-channel Normalize valid for RGBA / CMYK / palette inputs.
        with Image.open(os.path.join(path, img_name)) as opened:
            img = opened.convert("RGB")
        c = 3
        w, h = img.size
        org_img = img
        dst_w, dst_h = w * up, h * up

        # Pad so that (size - 2 * pad) is a whole number of strides.
        pad_h = h + pad * 2
        pad_w = w + pad * 2
        pad_h_off = (pad_h - 2 * pad) % stride_h
        pad_w_off = (pad_w - 2 * pad) % stride_w
        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off
        # CenterCrop with a size larger than the image pads it.
        img = transforms.CenterCrop((pad_h, pad_w))(img)
        w, h = img.size

        # Size of the padded super-resolved canvas.
        sr_h = h * up
        sr_w = w * up
        transform_mid = transforms.Compose([
            transforms.CenterCrop((h, w)),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
        ])
        # No Normalize here: the output is de-normalised before this runs.
        transform_last = transforms.Compose([
            transforms.CenterCrop((dst_h, dst_w)),
            transforms.ToTensor(),
        ])

        Y = torch.zeros((1, 3, sr_h, sr_w))
        img_ = transform_mid(img)
        top = 0
        down = top + lh
        while down <= h:
            left = 0
            right = left + lw
            while right <= w:
                X_ = img_[:, top:down, left:right].detach()
                X_ = X_.reshape(-1, c, lh, lw)
                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    Y_ = model(X_)
                cent_top = (top + pad) * up
                cent_left = (left + pad) * up
                cent_down = (down - pad) * up
                cent_right = (right - pad) * up
                Y[0, :, cent_top:cent_down, cent_left:cent_right] = Y_[:, :, y_top:y_down, y_left:y_right]
                # Move the tile to the right.
                left += stride_w
                right = left + lw
            # Move the tile row down.
            top += stride_h
            down = top + lh

        Y = verse_normalize(Y, means, stds)
        # Clamp so values outside [0, 1] do not wrap in the 8-bit conversion.
        Y = tensor_to_pil(Y[0].clamp(0, 1))
        Y = transform_last(Y)
        # NOTE(review): the reference is the *input* image passed through
        # CenterCrop at the output size, not a ground-truth full-size image.
        # psnr() divides by 255, hence the rescaling of the [0, 1] tensors.
        print(" Psnr = ", psnr(Y * 255.0, transform_last(org_img) * 255.0))
        Y = tensor_to_pil(Y)
        Y.save("./ESRT.png")
        display_img(Y, org_img)
        plt.pause(1)
        plt.clf()
"""
For training run this command:
python esrt.py train
For running run this command:
python esrt.py run
or
python esrt.py run YOUR_IMAGES_DIRECTORY
"""
if __name__ == '__main__':
    # Dispatch on the command line: "train", "run [DIR]", anything else
    # (including no arguments) runs the smoke test.
    args = sys.argv[1:]
    print(args, len(args))
    # `and` short-circuits, so args[0] is only read when an argument exists.
    # The earlier `&` evaluated both sides and raised IndexError on an empty
    # command line.
    if len(args) == 1 and args[0] == "train":
        train()
    elif len(args) == 1 and args[0] == "run":
        # run() needs its path argument; None selects the default directory.
        run(None)
    elif len(args) == 2 and args[0] == "run":
        run(args[1])
    else:
        tmp_test()
标签: python machine_learning DL
日历
最新微语
- 有的时候,会站在分叉路口,不知道向左还是右
2023-12-26 15:34
- 繁花乱开,鸟雀逐风。心自宁静,纷扰不闻。
2023-03-14 09:56
- 对于不可控的事,我们保持乐观,对于可控的事情,我们保持谨慎。
2023-02-09 11:03
- 小时候,
暑假意味着无忧无虑地玩很长一段时间,
节假意味着好吃好喝还有很多长期不见的小朋友来玩...
长大后,
这是女儿第一个暑假,
一个半月...
2022-07-11 08:54
- Watching the autumn leaves falling as you grow older together
2018-10-25 09:45
分类
最新评论
- Goonog
i get it now :) - 萧
@Fluzak:The web host... - Fluzak
Nice blog here! Also... - Albertarive
In my opinion you co... - ChesterHep
What does it plan? - ChesterHep
No, opposite. - mojoheadz
Everything is OK!... - Josephmaigh
I just want to say t... - ChesterHep
What good topic - AnthonyBub
Certainly, never it ...
发表评论: