苗火 Nicholas
[DL] ESRT: Transformer for Single Image Super-Resolution
2022-5-31 萧
"""

Author: Nicholas Xiao

Blog: log.anycle.com

"""

import os

import sys

import torch

import torch.nn as nn

import torch.nn.functional as F

from PIL import Image,ImageFilter

import matplotlib.pyplot as plt

import numpy as np

import torchvision

from torchvision import transforms

from torchvision.transforms import ToTensor

from torchvision.transforms import ToPILImage

from torch.autograd import Variable

import math

import random

import copy



# Per-channel RGB statistics used by every transforms.Normalize in this file
# and undone by verse_normalize before images are displayed or saved.
means = [0.485, 0.456, 0.406]

stds = [0.229, 0.224, 0.225]


# Global hyper-parameters shared by the data pipeline and the model.
opt = dict(
    pic_c=32,  # feature channels of the intermediate convolution layers
    pic_h=96,  # high-resolution height; low-resolution input is pic_h // pic_up (48)
    pic_w=96,  # high-resolution width; low-resolution input is pic_w // pic_up (48)
    pic_up=2,  # upscale factor
    heads=8,  # attention heads (default for LTB -> ET -> EMHA)
    # The following entries appear unused by the code in this script.
    vec_dim=512,
    N=6,
    x_vocab_len=0,
    y_vocab_len=0,
    sentence_len=80,
    batchs=1000,
    batch_size=10,
    pad=0,
    sof=1,
    eof=2,
    cat_padding=5,  # LR pixels discarded on each side of every tile in run()
)







"""

# errors

def psnr(img1, img2):

    img1 = Variable(img1)

    img2 = Variable(img2)

    mse = (  (img1/1.0) - (img2/1.0)   ) ** 2

    mse = np.mean( np.array(mse) )

    if mse < 1.0e-10:

        return 100

    return 10 * math.log10(255.0**2/mse)

"""



def psnr(img1, img2):

    img1 = Variable(img1)

    img2 = Variable(img2)

    mse = (  (img1/255.0) - (img2/255.0)   ) ** 2

    mse = np.mean( np.array(mse) )

    if mse < 1.0e-10:

        return 100

    return 10 * math.log10(1/math.sqrt(mse))





# ITU-R BT.601 RGB -> YCbCr matrix for 8-bit data (Y in [16, 235],
# Cb/Cr centred on 128). Shared by rgb2ycbcr and ycbcr2rgb.
_RGB2YCBCR_MATRIX = np.array([[0.257, 0.504, 0.098],
                              [-0.148, -0.291, 0.439],
                              [0.439, -0.368, -0.071]])

_YCBCR_SHIFT = np.array([16, 128, 128])


def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB array on the 0-255 scale to BT.601 YCbCr.

    Arrays that are not 3-D with exactly three channels are returned
    unchanged. The result is a float64 array of the same shape.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image

    rgb = rgb_image.astype(np.float32)
    # Row-vector form of applying the matrix to every pixel: p @ M.T == M @ p.
    return rgb @ _RGB2YCBCR_MATRIX.T + _YCBCR_SHIFT


def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) BT.601 YCbCr array back to 8-bit RGB.

    Arrays that are not 3-D with exactly three channels are returned
    unchanged. The result is rounded and saturated to [0, 255] as uint8.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image

    ycbcr = ycbcr_image.astype(np.float32)
    inverse = np.linalg.inv(_RGB2YCBCR_MATRIX)
    rgb = (ycbcr - _YCBCR_SHIFT) @ inverse.T
    # A bare astype(np.uint8) truncates toward zero and does not saturate
    # out-of-range values, so round and clip first.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)



class dataloader:
    """Endless batch iterator over the images in one directory.

    Training mode (``test=False``): each file is a high-resolution image. The
    target ``Y`` is that image resized to (pic_h, pic_w); the input ``X`` is
    the same image resized (bicubic) to (pic_h // pic_up, pic_w // pic_up).

    Test mode (``test=True``): files are consumed in consecutive pairs. ``X``
    is a centre crop of the first file at the low-resolution size and ``Y`` a
    centre crop of the second at the high-resolution size. Files are kept in
    sorted order in this mode so each pair stays together.
    NOTE(review): assumes sorted names place each LR image directly before its
    HR partner -- confirm against the dataset naming.

    Every image is converted to RGB and normalised with the module-level
    ``means``/``stds``; ``next()`` returns ``(X, Y)`` stacked to
    (batch_size, 3, H, W). When ``batchs > 0``, iteration stops after
    ``batchs`` batches and the counter resets, so the loader can be iterated
    again. File indices wrap around indefinitely.
    """

    def __init__(self, path, batchs,  batch_size, test= False, order = False):
        self.test = test
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path

        self.file_list = os.listdir(self.path)
        # Shuffling would separate the (LR, HR) pairs read in test mode.
        if order or test:
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # Half-open index range [file_start, file_stop): every file, including
        # the last one, is used.
        self.file_start = 0
        self.file_stop = self.file_number
        self.file_idx = self.file_start

        lr_size = (opt['pic_h'] // opt['pic_up'], opt['pic_w'] // opt['pic_up'])
        hr_size = (opt['pic_h'], opt['pic_w'])

        # Training input: the whole HR image shrunk to the LR size.
        self.downsample = transforms.Compose([
            transforms.Resize(lr_size, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
            ])

        # Training target: the whole HR image resized to the HR size.
        self.transform = transforms.Compose([
            transforms.Resize(hr_size, Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
            ])

        # Test input: centre crop at the LR size.
        self.center_down = transforms.Compose([
            transforms.CenterCrop(lr_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
            ])

        # Test target: centre crop at the HR size.
        self.center = transforms.Compose([
            transforms.CenterCrop(hr_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds),
            ])

    def get_len(self):
        """Return the number of files this loader cycles through."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        X = []
        Y = []
        for _ in range(self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)

        return torch.stack(X, 0), torch.stack(Y, 0)

    def _next(self):
        """Return one ``(X, Y)`` sample and advance the file cursor."""
        # Files consumed by this call: an (LR, HR) pair in test mode, else one.
        step = 2 if self.test else 1
        if self.file_idx + step > self.file_stop:
            self.file_idx = self.file_start
        if self.file_idx + step > self.file_stop:
            raise ValueError(
                "dataloader: '{}' needs at least {} image file(s)".format(self.path, step))

        file_path = os.path.join(self.path, self.file_list[self.file_idx])

        if self.test:
            file_path_hr = os.path.join(self.path, self.file_list[self.file_idx + 1])
            self.file_idx += 2
            X = self.center_down(self._open_rgb(file_path))
            Y = self.center(self._open_rgb(file_path_hr))
        else:
            self.file_idx += 1
            img = self._open_rgb(file_path)
            X = self.downsample(img)
            Y = self.transform(img)

        return X, Y

    @staticmethod
    def _open_rgb(file_path):
        """Load an image as an in-memory RGB copy and close the file handle."""
        # convert() always returns a new, fully loaded image, so it stays valid
        # after the with-block closes the file. Any non-RGB mode (L, P, RGBA,
        # ...) is converted because the transforms normalise three channels.
        with Image.open(file_path) as img:
            return img.convert("RGB")





def get_clones(module, N):
    """Return an nn.ModuleList with N independent deep copies of ``module``.

    The copies start with identical weights but share no parameters.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones



# RU

class RU(nn.Module):
    """Reduction/expansion residual unit.

    A 1x1 conv halves the channels and a second 1x1 conv restores them; the
    branch is scaled by the learnable scalar ``lamda`` and added back:
    ``x + lamda * expansion(reduction(x))``.
    """

    def __init__(self, c, lamda):
        super().__init__()
        half = int(c // 2)
        self.lamda = nn.Parameter(torch.FloatTensor([lamda]))
        self.reduction = nn.Conv2d(c, half, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv2d(half, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        branch = self.expansion(self.reduction(x))
        return branch * self.lamda + x



# Adaptive Residual Feature Block

class ARFB(nn.Module):
    """Adaptive residual feature block.

    Runs ``self.ru`` twice in a row, concatenates both outputs (2c channels),
    fuses them with a 1x1 conv (2c -> 2c) and a 3x3 conv (2c -> c), and
    returns ``lamda_x * x + lamda_res * fused`` with both scales learnable.
    The two RU passes use the same RU instance, so they share weights.
    """

    def __init__(self, c, lamda_res, lamda_x, lamda_ru):
        super().__init__()
        self.ru = RU(c, lamda_ru)
        self.conv1 = nn.Conv2d( 2*c, 2*c, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d( 2*c, c, kernel_size=3, stride=1, padding=1)
        self.lamda_res = nn.Parameter(torch.FloatTensor([lamda_res]))
        self.lamda_x = nn.Parameter(torch.FloatTensor([lamda_x]))

    def forward(self, x):
        first = self.ru(x)
        second = self.ru(first)
        fused = self.conv3(self.conv1(torch.cat((first, second), dim=1)))
        return x * self.lamda_x + fused * self.lamda_res



# High-frequency Filtering Module

class HFM(nn.Module):
    """High-frequency filter: input minus its k x k block-average.

    The input is average-pooled with a k x k window and stride k, upsampled
    back with nearest-neighbour, and subtracted from the input, leaving only
    the detail within each block. ``c`` is accepted but unused. Height and
    width must be divisible by ``k``; otherwise the subtraction's shapes
    disagree.
    """

    def __init__(self, c, k):
        super().__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=k, stride=k, padding=0)
        self.upsample = nn.Upsample(scale_factor=k, mode='nearest')

    def forward(self, TL):
        smooth = self.upsample(self.avgPool(TL))
        return TL - smooth



class CA(nn.Module):
    """Channel attention.

    Pipeline: global average pool -> 1x1 conv (c -> c // reduction) -> ReLU
    -> 1x1 conv (back to c) -> sigmoid. The resulting per-channel weights
    rescale ``x``.
    """

    def __init__(self, c, reduction=16):
        super().__init__()
        squeezed = c // reduction
        self.avgPool = nn.AdaptiveAvgPool2d(1)
        self.convDu = nn.Sequential(
            nn.Conv2d(c, squeezed, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, c, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.convDu(self.avgPool(x))
        return x * weights



class HPB(nn.Module):
    """One stage of the LCB backbone.

    Pipeline:
      1. ARFB, then HFM (keeps the high-frequency residual).
      2. "Half" branch: scale by 1/2, run the ARFB five times, scale back by 2.
      3. Main branch: one more ARFB.
      4. Concatenate (main, half), fuse with a 1x1 conv (2c -> c), reweight
         with channel attention, apply a final ARFB, and add the block input.

    NOTE(review): every ARFB application reuses the single ``self.arfb``, so
    all of them share weights. This is kept as-is so pickled checkpoints of
    this model stay loadable.
    """

    def __init__(self, c, k, lamda_res=0.5, lamda_x=0.5, lamda_ru=0.5):
        super().__init__()
        self.arfb = ARFB(c, lamda_res, lamda_x, lamda_ru)
        self.hfm = HFM(c,k)
        self.conv1 = nn.Conv2d(2*c, c, kernel_size=1, stride=1, padding=0)
        self.ca = CA(c, 2)

    def forward(self, x):
        # main line
        y = self.arfb(x)
        y = self.hfm(y)

        # half branch
        y2 = y / 2
        for _ in range(5):
            y2 = self.arfb(y2)
        # Rescale the branch's own output (not y); otherwise the five ARFB
        # passes above would be computed and then discarded.
        y2 = y2 * 2

        # main line
        y = self.arfb(y)
        y = torch.cat((y, y2), dim=1)
        y = self.conv1(y)
        y = self.ca(y)
        y = self.arfb(y)

        return x + y



class LCB(nn.Module):
    """Stack of ``hpb_num`` HPB stages applied in sequence.

    The stages come from get_clones, i.e. deep copies of one HPB: they start
    with identical weights but are separate modules.
    """

    def __init__(self,c,k=2, hpb_num=3):
        super().__init__()
        self.hpb_num = hpb_num
        self.layers = get_clones( HPB(c, k), hpb_num)

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out





# attention

def attention(q,k,v,dk,mask=None,dropout=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(dk)) @ v``.

    ``mask`` (optional) gets a new axis at dim 1. Score positions where it is
    0 are set to -1e9 before the softmax. ``dropout`` (optional module) is
    applied to the attention weights.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)

    weights = F.softmax(scores, dim=-1)
    weights = weights if dropout is None else dropout(weights)

    return torch.matmul(weights, v)



# B*C*N

class EMHA(nn.Module):
    """Segment-wise multi-head attention over a (B, C, N) token map.

    Forward pass:
      1. A 1x1 Conv1d reduces the channels C -> C1 = C // 2.
      2. Q/K/V are projected and split into ``heads`` heads of size
         dk = C1 // heads.
      3. The N tokens are cut into ``s`` equal consecutive segments, and
         attention is computed only within each segment.
      4. The segments are re-joined, and a 1x1 Conv1d restores C channels.

    N must be divisible by ``s``, and C1 by ``heads``, for the reshapes to be
    valid.
    """

    def __init__(self, c, heads=8, s=4, dropout=0.1):
        super().__init__()
        self.s = s
        self.c1 = c // 2
        self.dk = self.c1 // heads
        self.h = heads
        self.dropout = nn.Dropout( dropout)

        self.reduction = nn.Conv1d(c, self.c1, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv1d(self.c1, c, kernel_size=1, stride=1, padding=0)
        self.lq = nn.Linear(self.c1, self.c1)
        self.lk = nn.Linear(self.c1, self.c1)
        self.lv = nn.Linear(self.c1, self.c1)

    def forward(self, x):
        batch_size = x.size(0)

        # (B, C, N) -> (B, C1, N) -> (B, N, C1)
        tokens = self.reduction(x).transpose(1, 2)

        # Project and split heads; each tensor is (B, N, M, dk), M = heads.
        projected = [proj(tokens).view(batch_size, -1, self.h, self.dk)
                     for proj in (self.lq, self.lk, self.lv)]
        b, n, m, d = projected[0].shape

        # (B, N, M, dk) -> (B, S, N_s, M, dk) -> (B, S, M, N_s, dk) -> (S, B, M, N_s, dk)
        q, k, v = (t.view(b, self.s, -1, m, d).transpose(2, 3).transpose(0, 1)
                   for t in projected)

        # Attend within each segment, then re-join along tokens: (B, M, N, dk)
        joined = torch.cat(
            [attention(q[i], k[i], v[i], self.dk, None, self.dropout)
             for i in range(self.s)],
            dim=2,
        )

        # (B, M, N, dk) -> (B, N, C1) -> (B, C1, N) -> (B, C, N)
        merged = joined.transpose(1, 2).contiguous().view(batch_size, -1, self.c1)
        return self.expansion(merged.transpose(1, 2))



class Norm(nn.Module):
    """Layer-norm-style normalisation over the last dimension.

    Computes ``alpha * (x - mean) / (std + eps) + bias``:
      * ``std`` is ``Tensor.std``'s default (unbiased) estimate.
      * ``eps`` is added to the std, not to the variance.
      * ``alpha`` and ``bias`` are learnable vectors of length ``d_model``.
    """

    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias



# Efficient Transformer

# (B C H W)  to (B C N), N = d_model

class ET(nn.Module):
    """Pre-norm transformer block over (B, C, N) tokens, with N = d_model.

    Computes ``y = x + EMHA(norm1(x))`` and then ``y + mlp(norm2(y))``. Both
    norms and the single-Linear "mlp" act on the last axis (length d_model).
    """

    def __init__(self, c, d_model, heads):
        super().__init__()
        self.c = c
        self.d_model = d_model
        self.norm1 = Norm( self.d_model )
        self.norm2 = Norm( self.d_model )
        self.emha = EMHA(c, heads)
        self.mlp = nn.Linear(self.d_model, self.d_model)

    def forward(self, x):
        attended = x + self.emha(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))



class LTB(nn.Module):
    """Transformer branch operating on unfolded k x k patches.

    Steps:
      1. Unfold a (B, c, h, w) map into (B, c*k*k, h*w) tokens.
      2. Run ``et_num`` ET blocks (deep copies of one ET).
      3. Fold the tokens back to (B, c, h, w).

    ``h`` and ``w`` are fixed at construction, since ET normalises over
    N = h*w and Fold emits (h, w). The defaults are the LR size taken from
    ``opt``.

    NOTE(review): padding=1 is hard-coded, which keeps the token count equal
    to h*w only for k=3. nn.Fold sums overlapping patch values, so each output
    pixel accumulates up to k*k contributions rather than an average --
    confirm this scaling is intended.
    """

    def __init__(self, c, h=int(opt['pic_h']/opt['pic_up']), w=int(opt['pic_w']/opt['pic_up']), k=3, heads=opt['heads'], et_num=1):
        super().__init__()
        self.et_num = et_num
        self.unfold = nn.Unfold(kernel_size=k, dilation=1, stride=1, padding=1)
        self.layers = get_clones( ET(c*k*k, h*w, heads), et_num)
        self.fold = nn.Fold( (h,w), kernel_size=k, dilation=1, stride=1, padding=1)

    def forward(self, x):
        tokens = self.unfold(x)
        for block in self.layers:
            tokens = block(tokens)
        return self.fold(tokens)

        

class ESRT(nn.Module):
    """Super-resolution network: shallow conv, CNN + transformer body, upsampler.

    Args:
        c: feature channels. Must be divisible by ``upscale ** 2`` because
            PixelShuffle turns C channels into C / upscale**2.
        h, w: high-resolution output size. The transformer branch (LTB) is
            built for low-resolution inputs of (h // upscale, w // upscale).
        upscale: integer upscale factor.

    ``forward(x)`` takes x of shape (B, 3, h // upscale, w // upscale). conv1
    features feed two paths that are summed:
      * main path: LCB -> LTB -> conv2 -> PixelShuffle -> conv3;
      * skip path: PixelShuffle -> conv4.
    """

    def __init__(self, c, h, w, upscale):
        super().__init__()
        if c % (upscale * upscale) != 0:
            raise ValueError(
                "c ({}) must be divisible by upscale**2 ({})".format(c, upscale * upscale))

        self.train_loops = 0
        self.lr = 2e-4
        # Channels left after PixelShuffle(upscale).
        self.c1 = c // (upscale * upscale)
        self.ltb = LTB(c, h=h // upscale, w=w // upscale)
        self.lcb = LCB(c)
        self.conv1 = nn.Conv2d( 3, c, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d( c, c, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d( self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d( self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.up = nn.PixelShuffle(upscale)
        self.upscale = upscale

    def forward(self, x):
        y1 = self.conv1(x)

        # main line
        y2 = self.lcb(y1)
        y2 = self.ltb(y2)
        y2 = self.conv2(y2)
        y2 = self.up(y2)
        y2 = self.conv3(y2)

        # sub line
        y1 = self.up(y1)
        y1 = self.conv4(y1)

        return y2 + y1

    def init_weights(self):
        """Placeholder: only prints; layers keep PyTorch's default init."""
        print(" init weights \n")





def tmp_test():
    """Smoke test: run one random low-resolution batch through ESRT.

    The model is built from ``opt`` because LTB's transformer only accepts
    (pic_h // pic_up) x (pic_w // pic_up) inputs. The input has 3 channels to
    match ESRT.conv1. Prints the input and output shapes.
    """
    net = ESRT(opt["pic_c"], opt["pic_h"], opt["pic_w"], opt["pic_up"])
    net.eval()

    a = torch.randn(1, 3, opt["pic_h"] // opt["pic_up"], opt["pic_w"] // opt["pic_up"])
    with torch.no_grad():
        b = net(a)

    print(a.shape, "\n", b.shape)







tensor_to_pil = ToPILImage()

model = ESRT(opt["pic_c"],opt["pic_h"],opt["pic_w"],opt["pic_up"])

# Dataset directories. Set ESRT_TRAIN_DIR / ESRT_TEST_DIR to override the
# machine-specific defaults without editing the script.
TRAIN_DIR = os.environ.get("ESRT_TRAIN_DIR", '/home/nicholas/Documents/dataset/DIV2K_train_LR_unknown/X2')
TEST_DIR = os.environ.get("ESRT_TEST_DIR", './dataset/images')

# batchs=10000000 makes the training loader effectively endless; the test
# loader yields one batch per iteration. Both use batch_size 5.
dataloader_train = dataloader(TRAIN_DIR, 10000000, 5)
dataloader_test = dataloader(TEST_DIR, 1, 5, test=True)

# train() resumes from ./model/esrt.m when new_training == 0.
if os.path.exists("./model/esrt.m"):
    print( " model exists ... \n")
    new_training = 0
else:
    print( " model not exists, will create a new one. \n")
    new_training = 1



def display_img(img1, img2):
    """Show ``img1`` (left) and ``img2`` (right) in a 1x2 grid, then plt.show()."""
    for slot, picture in enumerate((img1, img2), start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)
    plt.show()



def verse_normalize(img, means, stds):
    """Undo ``transforms.Normalize(means, stds)``: return ``img * std + mean``.

    Works for a single image (C, H, W) or a batch (B, C, H, W); the channel
    axis is dim -3 in both cases, and C must equal ``len(means)``. Returns a
    new tensor and leaves ``img`` unmodified.
    """
    # (C, 1, 1) broadcasts over H and W, and over a leading batch axis if any.
    mean_t = torch.as_tensor(means, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(stds, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t



def train():
    """Train the global ``model`` on ``dataloader_train``.

    Starts fresh (``model.init_weights()``) when ``new_training == 1``;
    otherwise resumes from the pickled model at ./model/esrt.m. Uses Adam with
    MSE loss and halves the learning rate every 2000 batches. Every 10 batches
    the model is saved to ./model/esrt.m and ``test()`` is run.
    """
    global model

    if new_training == 1:
        print(" new training ...")
        model.init_weights()
    else:
        print(" continue training ...")
        # The checkpoint is a whole pickled nn.Module (see torch.save below), so
        # it must come from a trusted local file. torch >= 2.6 defaults to
        # weights_only=True, which cannot restore a full module.
        try:
            model = torch.load("./model/esrt.m", weights_only=False)
        except TypeError:  # older torch without the weights_only argument
            model = torch.load("./model/esrt.m")

    loss = nn.MSELoss()
    plt.ion()

    # Halve the learning rate every 2000 batches.
    lr_index = 0
    optimizer = torch.optim.Adam(model.parameters(), lr = model.lr, betas = (0.9, 0.999))

    for X, Y in dataloader_train:
        lr_index += 1
        if lr_index % 2000 == 0:
            model.lr /= 2
            print( " model lr = ", model.lr)
            # Update the rate in place; rebuilding Adam would discard its
            # running moment estimates.
            for group in optimizer.param_groups:
                group["lr"] = model.lr

        # test() switches to eval mode, so re-enter train mode every batch.
        model.train()

        Y_ = model(X)
        optimizer.zero_grad()
        cost = loss(Y_, Y)
        cost.backward()
        optimizer.step()

        print( "batch:{}, loss  is {:.4f} ".format(model.train_loops, cost.item()) )

        model.train_loops += 1
        if model.train_loops % 10 == 0:
            os.makedirs("./model", exist_ok=True)
            torch.save(model, "./model/esrt.m")
            test()

    plt.ioff()



def test():
    """Evaluate the global ``model`` on ``dataloader_test`` and show results.

    For the first sample of each batch:
      * shows the model output (left subplot) and the de-normalised
        low-resolution input (right subplot);
      * prints the PSNR between the output and the high-resolution target.
    """
    global model

    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        for X, Y in dataloader_test:
            _, _, dst_h, dst_w = Y.shape

            Y_ = model(X)

            # De-normalise 4-D (1, C, H, W) slices so the channel axis is dim 1,
            # and clamp to [0, 1] so ToPILImage's byte conversion cannot wrap.
            # clone() keeps X/Y/Y_ intact even if verse_normalize works in place.
            OX = tensor_to_pil(verse_normalize(X[:1].clone(), means, stds)[0].clamp(0, 1))
            OY_ = tensor_to_pil(verse_normalize(Y_[:1].clone(), means, stds)[0].clamp(0, 1))
            OY = tensor_to_pil(verse_normalize(Y[:1].clone(), means, stds)[0].clamp(0, 1))

            plt.subplot(1, 2, 2)
            plt.imshow(OX)

            plt.subplot(1, 2, 1)
            plt.imshow(OY_)

            transform_last = transforms.Compose([
                transforms.CenterCrop( (dst_h, dst_w) ),
                transforms.ToTensor(),
                ])
            print(" Psnr = ", psnr(transform_last(OY_), transform_last(OY)) )

            plt.show()
            plt.pause(0.1)



def run(path=None):
    """Super-resolve every image in ``path`` (default "./images") tile by tile.

    Processing, per image:
      1. Zero-pad the image and cut it into overlapping LR tiles of
         (pic_h // pic_up) x (pic_w // pic_up), the input size the model's
         transformer branch is built for.
      2. Run each tile through the model and keep only the centre of its
         output, dropping ``opt['cat_padding']`` LR pixels per side.
      3. Stitch the kept centres, crop to ``pic_up`` times the original size,
         save to ./ESRT.png (overwritten for each image), and display next to
         the input.

    Raises:
        ValueError: if ``cat_padding`` leaves no positive tile stride.
    """
    # Whole pickled nn.Module written by train(); trusted local file only.
    try:
        model = torch.load("./model/esrt.m", weights_only=False)
    except TypeError:  # older torch without the weights_only argument
        model = torch.load("./model/esrt.m")
    model.eval()

    if path is None:
        path = "./images"

    up = opt['pic_up']
    pad = opt['cat_padding']

    # LR tile size fed to the model.
    lh = opt['pic_h'] // up
    lw = opt['pic_w'] // up

    # Tile stride: consecutive tiles overlap by 2 * pad LR pixels.
    stride_h = lh - pad * 2
    stride_w = lw - pad * 2
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError("cat_padding ({}) is too large for {}x{} tiles".format(pad, lh, lw))

    # Central region of each SR tile that is kept.
    y_top = pad * up
    y_left = pad * up
    y_down = (lh - pad) * up
    y_right = (lw - pad) * up

    for img_name in os.listdir(path):
        # Load an in-memory RGB copy and release the file handle.
        with Image.open(os.path.join(path, img_name)) as src:
            org_img = src.convert("RGB")

        w, h = org_img.size
        dst_w, dst_h = w * up, h * up

        # Pad by `pad` on each side, then grow so the interior spans a whole
        # number of strides and every pixel falls in some tile's kept centre.
        pad_h = h + pad * 2
        pad_w = w + pad * 2
        pad_h_off = (pad_h - 2 * pad) % stride_h
        pad_w_off = (pad_w - 2 * pad) % stride_w
        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off

        # CenterCrop to a size larger than the image zero-pads it.
        img = transforms.CenterCrop((pad_h, pad_w))(org_img)
        w, h = img.size

        # Size of the stitched SR canvas.
        sr_h = h * up
        sr_w = w * up

        transform_mid = transforms.Compose([
            transforms.CenterCrop( (h, w) ),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])
        # No Normalize here: Y is de-normalised before this transform.
        transform_last = transforms.Compose([
            transforms.CenterCrop( (dst_h, dst_w) ),
            transforms.ToTensor(),
            ])

        Y = torch.zeros( (1, 3, sr_h, sr_w) )
        img_ = transform_mid(img)

        with torch.no_grad():
            for top in range(0, h - lh + 1, stride_h):
                for left in range(0, w - lw + 1, stride_w):
                    X_ = img_[:, top:top + lh, left:left + lw].unsqueeze(0)
                    Y_ = model(X_)

                    cent_top = (top + pad) * up
                    cent_left = (left + pad) * up
                    cent_down = (top + lh - pad) * up
                    cent_right = (left + lw - pad) * up

                    Y[0, :, cent_top:cent_down, cent_left:cent_right] = \
                        Y_[:, :, y_top:y_down, y_left:y_right]

        Y = verse_normalize(Y, means, stds)
        # Clamp so ToPILImage's byte conversion cannot wrap out-of-range values.
        Y = tensor_to_pil(Y[0].clamp(0, 1))
        Y = transform_last(Y)
        # NOTE(review): the reference is the LR input zero-padded to the SR
        # size, not a ground-truth HR image, so this is not a fidelity measure.
        print(" Psnr = ", psnr(Y, transform_last(org_img)) )
        Y = tensor_to_pil(Y)
        Y.save("./ESRT.png")
        display_img( Y,  org_img)

        plt.pause(1)
        plt.clf()





"""



For training run this command:

python esrt.py train



For running run this command:

python esrt.py run

or

python esrt.py run YOUR_IMAGES_DIRECTORY



"""

if __name__ == '__main__':

    args = sys.argv[1:]

    print( args, len(args))

    if (len(args) == 1) & (args[0] == "train"):

        train()

    elif (len(args) == 1) & (args[0] == "run"):

        run()

    elif (len(args) == 2) & (args[0] == "run"):

        run( args[1] )

    else:

        tmp_test()






发表评论:
昵称

邮件地址 (选填)

个人主页 (选填)

内容