[DL] ESRT: Transformer for Single Image Super-Resolution

2022-5-31 写技术

"""
Author: Nicholas Xiao
Blog: log.anycle.com
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import math
import random
import copy

# Per-channel (R, G, B) statistics handed to transforms.Normalize; these are the
# standard ImageNet mean/std values.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]


# Global hyper-parameters (insertion order preserved).
opt = dict(
    pic_c=32,          # feature channels of the intermediate conv layers (ESRT's `c`)
    pic_h=96,          # high-resolution patch height; LR height is pic_h // pic_up = 48
    pic_w=96,          # high-resolution patch width; LR width is pic_w // pic_up = 48
    pic_up=2,          # upscale factor
    heads=8,           # attention heads (default for LTB)
    vec_dim=512,       # not read anywhere in this file
    N=6,               # not read anywhere in this file
    x_vocab_len=0,     # not read anywhere in this file
    y_vocab_len=0,     # not read anywhere in this file
    sentence_len=80,   # not read anywhere in this file
    batchs=1000,       # not read here; the dataloaders are built with literals
    batch_size=10,     # not read here; the dataloaders are built with literals
    pad=0,             # not read anywhere in this file
    sof=1,             # not read anywhere in this file
    eof=2,             # not read anywhere in this file
    cat_padding=5,     # LR border (pixels) dropped around each tile in run()
)



"""
# errors
def psnr(img1, img2):
    img1 = Variable(img1)
    img2 = Variable(img2)
    mse = (  (img1/1.0) - (img2/1.0)   ) ** 2
    mse = np.mean( np.array(mse) )
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0**2/mse)
"""

def psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Args:
        img1, img2: images of identical shape (torch tensors or array-likes)
            whose pixel values lie in ``[0, data_range]``. Every caller in this
            file passes ``transforms.ToTensor()`` output, i.e. values in [0, 1].
        data_range: largest possible pixel value (1.0 for ToTensor output,
            255.0 for 8-bit images).

    Returns:
        ``10 * log10(data_range**2 / MSE)``, or 100 when the images are
        numerically identical (normalised MSE below 1e-10).
    """
    if isinstance(img1, torch.Tensor):
        img1 = img1.detach().cpu().numpy()
    if isinstance(img2, torch.Tensor):
        img2 = img2.detach().cpu().numpy()
    # Normalise to [0, 1] so the identical-image threshold is scale independent.
    a = np.asarray(img1, dtype=np.float64) / data_range
    b = np.asarray(img2, dtype=np.float64) / data_range
    mse = float(np.mean((a - b) ** 2))
    if mse < 1.0e-10:
        return 100
    # With values in [0, 1] the peak is 1, so PSNR = 10 * log10(1 / MSE).
    return 10 * math.log10(1.0 / mse)


# ITU-R BT.601 RGB -> YCbCr matrix for 8-bit values (studio swing). The Y row
# sums to 0.859, so white (255, 255, 255) maps to Y = 16 + 219 = 235; the Cb and
# Cr rows sum to 0, so greys map to Cb = Cr = 128.
_YCBCR_MATRIX = np.array([[0.257, 0.504, 0.098],
                          [-0.148, -0.291, 0.439],
                          [0.439, -0.368, -0.071]])
_YCBCR_SHIFT = np.array([16, 128, 128])


def rgb2ycbcr(rgb_image):
    """Convert an (H, W, 3) RGB image with 0-255 values to YCbCr.

    Input that is not a 3-D array with 3 channels is returned unchanged.

    Returns:
        float64 array of the same shape: Y in [16, 235], Cb/Cr in [16, 240]
        for in-range input.
    """
    if len(rgb_image.shape) != 3 or rgb_image.shape[2] != 3:
        return rgb_image
        #raise ValueError("input image is not a rgb image")
    rgb = rgb_image.astype(np.float64)
    # Apply the 3x3 matrix to every pixel at once instead of looping in Python.
    return rgb @ _YCBCR_MATRIX.T + _YCBCR_SHIFT


def ycbcr2rgb(ycbcr_image):
    """Convert an (H, W, 3) YCbCr image (as produced by rgb2ycbcr) to uint8 RGB.

    Input that is not a 3-D array with 3 channels is returned unchanged.
    Results are rounded and clipped to [0, 255] before the uint8 cast, so
    out-of-gamut values saturate instead of wrapping around.
    """
    if len(ycbcr_image.shape) != 3 or ycbcr_image.shape[2] != 3:
        return ycbcr_image
        #raise ValueError("input image is not a rgb image")
    ycbcr = ycbcr_image.astype(np.float64)
    inverse = np.linalg.inv(_YCBCR_MATRIX)
    rgb = (ycbcr - _YCBCR_SHIFT) @ inverse.T
    return np.clip(np.round(rgb), 0, 255).astype(np.uint8)

class dataloader:
    """Endless image-folder iterator yielding (low-res, high-res) tensor batches.

    Training mode (``test=False``): every file is a high-resolution image. The
    model input ``X`` is its bicubic resize to
    ``(pic_h // pic_up, pic_w // pic_up)`` and the target ``Y`` is its
    bilinear resize to ``(pic_h, pic_w)``.

    Test mode (``test=True``): files are consumed in consecutive pairs. The
    first file of each pair is centre-cropped to the low-resolution size
    (``X``) and the second to the high-resolution size (``Y``).

    ``X`` and ``Y`` are both normalised with the module-level ``means``/``stds``.
    When the file list is exhausted, iteration wraps back to the first file.

    Args:
        path: directory whose entries are all opened as images.
        batchs: batches per pass before ``StopIteration`` (the counter then
            resets); ``<= 0`` means never stop.
        batch_size: number of samples stacked into each batch.
        test: enable paired test mode.
        order: iterate in sorted file-name order instead of a random shuffle.
            Test mode always sorts, because pairing depends on file order.
    """

    def __init__(self, path, batchs, batch_size, test=False, order=False):
        self.test = test
        self.batchs_cnt = 0
        self.batchs = batchs
        self.batch_size = batch_size
        self.path = path
        self.file_list = os.listdir(self.path)
        if order or test:
            # Test pairs are consecutive files; shuffling would pair unrelated
            # images. NOTE(review): assumes sorted names place each LR file
            # directly before its HR counterpart - confirm against the dataset.
            self.file_list = sorted(self.file_list)
        else:
            random.shuffle(self.file_list)
        self.file_number = len(self.file_list)

        # Usable index range; file_stop is the index of the last file.
        self.file_start = 0
        self.file_stop = self.file_number - 1
        self.file_idx = self.file_start

        lr_size = (opt['pic_h'] // opt['pic_up'], opt['pic_w'] // opt['pic_up'])
        hr_size = (opt['pic_h'], opt['pic_w'])

        # Training input: bicubic downsample of the HR image.
        self.downsample = transforms.Compose([
            transforms.Resize(lr_size, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])

        # Training target: the HR image resized to the HR patch size.
        self.transform = transforms.Compose([
            transforms.Resize(hr_size, Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])

        # Test input: centre crop at the LR patch size.
        self.center_down = transforms.Compose([
            transforms.CenterCrop(lr_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])

        # Test target: centre crop at the HR patch size.
        self.center = transforms.Compose([
            transforms.CenterCrop(hr_size),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])

    def get_len(self):
        """Return ``file_stop - file_start`` (one less than the number of files)."""
        return self.file_stop - self.file_start

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(X, Y)`` batch, each stacked along a new dim 0."""
        if self.batchs > 0 and self.batchs_cnt >= self.batchs:
            # Reset so the same loader can be iterated again (e.g. test()).
            self.batchs_cnt = 0
            raise StopIteration
        self.batchs_cnt += 1

        X = []
        Y = []
        for _ in range(self.batch_size):
            X_, Y_ = self._next()
            X.append(X_)
            Y.append(Y_)

        return torch.stack(X, 0), torch.stack(Y, 0)

    @staticmethod
    def _load_rgb(file_path):
        """Open an image as 3-channel RGB and close the underlying file."""
        with Image.open(file_path) as img:
            # convert() returns a fully loaded copy, valid after the file closes.
            return img.convert("RGB")

    def _next(self):
        """Load one ``(X, Y)`` sample and advance the file cursor, wrapping around."""
        needed = 2 if self.test else 1
        if self.file_number < needed:
            raise RuntimeError(
                "dataloader: '{}' holds {} file(s), need at least {}".format(
                    self.path, self.file_number, needed))
        # Wrap once this sample would read past the last file (index file_stop).
        if self.file_idx + needed - 1 > self.file_stop:
            self.file_idx = self.file_start

        file_path = os.path.join(self.path, self.file_list[self.file_idx])

        if self.test:
            file_path_hr = os.path.join(self.path, self.file_list[self.file_idx + 1])
            self.file_idx += 2
            X = self.center_down(self._load_rgb(file_path))
            Y = self.center(self._load_rgb(file_path_hr))
        else:
            self.file_idx += 1
            hr = self._load_rgb(file_path)
            X = self.downsample(hr)
            Y = self.transform(hr)

        return X, Y


def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

# RU
class RU(nn.Module):
    """Residual unit with a learnable residual scale.

    The input goes through a 1x1 conv reducing c -> c // 2 channels and a 1x1
    conv expanding back to c. The result is multiplied by the learnable scalar
    ``lamda`` and added to the input. The output shape equals the input shape.

    Args:
        c: number of input/output channels.
        lamda: initial value of the learnable residual scale.
    """

    def __init__(self, c, lamda):
        super().__init__()
        half = int(c // 2)
        self.lamda = nn.Parameter(torch.FloatTensor([lamda]))
        self.reduction = nn.Conv2d(c, half, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv2d(half, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = self.expansion(self.reduction(x)) * self.lamda
        return residual + x

# Adaptive Residual Feature Block
class ARFB(nn.Module):
    """Adaptive Residual Feature Block.

    Two successive RU passes are concatenated (2c channels), fused by a 1x1
    conv (2c -> 2c) and a 3x3 conv (2c -> c). The result is scaled by the
    learnable ``lamda_res`` and added to the input scaled by the learnable
    ``lamda_x``. The output keeps the input's (B, c, H, W) shape.

    NOTE(review): the same ``self.ru`` instance is applied twice, so both
    residual units share one set of weights - confirm this is intended.
    """

    def __init__(self, c, lamda_res, lamda_x, lamda_ru):
        super().__init__()
        self.ru = RU(c, lamda_ru)
        self.conv1 = nn.Conv2d(2 * c, 2 * c, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(2 * c, c, kernel_size=3, stride=1, padding=1)
        self.lamda_res = nn.Parameter(torch.FloatTensor([lamda_res]))
        self.lamda_x = nn.Parameter(torch.FloatTensor([lamda_x]))

    def forward(self, x):
        first = self.ru(x)
        second = self.ru(first)
        fused = self.conv3(self.conv1(torch.cat((first, second), dim=1)))
        return x * self.lamda_x + fused * self.lamda_res

# High-frequency Filtering Module
class HFM(nn.Module):
    """High-frequency Filtering Module.

    Returns the input minus a low-frequency copy of itself: a k x k average
    pool (stride k) followed by nearest-neighbour upsampling by k. H and W must
    be multiples of ``k`` for the subtraction shapes to match. ``c`` is accepted
    for signature symmetry with the other blocks but is unused.
    """

    def __init__(self, c, k):
        super().__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=k, stride=k, padding=0)
        self.upsample = nn.Upsample(scale_factor=k, mode='nearest')

    def forward(self, TL):
        low_freq = self.upsample(self.avgPool(TL))
        return TL - low_freq

class CA(nn.Module):
    """Channel attention.

    A global average pool is followed by a 1x1 bottleneck
    (c -> c // reduction -> c) with ReLU and a sigmoid gate. The per-channel
    gate in (0, 1) rescales the input.
    """

    def __init__(self, c, reduction=16):
        super().__init__()
        squeezed = c // reduction
        self.avgPool = nn.AdaptiveAvgPool2d(1)
        self.convDu = nn.Sequential(
            nn.Conv2d(c, squeezed, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, c, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.convDu(self.avgPool(x))
        return x * gate

class HPB(nn.Module):
    """Feature block combining ARFB, HFM, channel attention and a residual skip.

    Data flow (every tensor keeps the input's (B, c, H, W) shape):
      y  = HFM(ARFB(x))
      y2 = 5 x ARFB(y / 2), then * 2              (side branch)
      out = x + ARFB(CA(conv1x1(cat(ARFB(y), y2))))

    A single ``self.arfb`` instance is used at all eight ARFB positions, so they
    share weights.

    Args:
        c: number of feature channels.
        k: HFM pooling/upsampling factor (H and W must be multiples of k).
        lamda_res, lamda_x, lamda_ru: initial ARFB/RU scaling parameters.
    """

    def __init__(self, c, k, lamda_res=0.5, lamda_x=0.5, lamda_ru=0.5):
        super().__init__()
        self.arfb = ARFB(c, lamda_res, lamda_x, lamda_ru)
        self.hfm = HFM(c, k)
        self.conv1 = nn.Conv2d(2 * c, c, kernel_size=1, stride=1, padding=0)
        self.ca = CA(c, 2)

    def forward(self, x):
        # main line
        y = self.arfb(x)
        y = self.hfm(y)

        # side branch: scale down, refine with five ARFB passes, scale back up
        y2 = y / 2
        for _ in range(5):
            y2 = self.arfb(y2)
        # Scale the *refined* branch; using `y` here would discard the loop.
        y2 = y2 * 2

        # main line
        y = self.arfb(y)
        y = torch.cat((y, y2), dim=1)
        y = self.conv1(y)
        y = self.ca(y)
        y = self.arfb(y)
        return x + y

class LCB(nn.Module):
    """Chain of ``hpb_num`` HPB blocks applied one after another.

    The blocks are independent deep copies of one ``HPB(c, k)``. The output has
    the same shape as the input.
    """

    def __init__(self, c, k=2, hpb_num=3):
        super().__init__()
        self.hpb_num = hpb_num
        self.layers = get_clones(HPB(c, k), hpb_num)

    def forward(self, x):
        for block in self.layers[:self.hpb_num]:
            x = block(x)
        return x


# attention
def attention(q, k, v, dk, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(dk)) @ v``.

    Args:
        q, k, v: tensors whose last two dims are (sequence, features).
        dk: feature size used for the sqrt scaling.
        mask: optional tensor; a new axis is inserted at dim 1, and positions
            where it equals 0 get a score of -1e9 before the softmax.
        dropout: optional module applied to the attention weights.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(dk)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ v

# B*C*N
class EMHA(nn.Module):
    """Efficient multi-head self-attention over a (B, C, N) sequence.

    Channels are reduced C -> C // 2 by a 1x1 Conv1d and projected to q/k/v
    with ``heads`` heads of ``(C // 2) // heads`` features each. The N positions
    are split into ``s`` contiguous segments, and attention is computed only
    within each segment. The result is expanded back to C channels, so the
    output shape is (B, C, N).

    N must be divisible by ``s``, and ``C // 2`` should equal ``heads * dk``.
    """

    def __init__(self, c, heads=8, s=4, dropout=0.1):
        super().__init__()
        self.s = s
        self.c1 = c // 2
        self.dk = self.c1 // heads
        self.h = heads
        self.dropout = nn.Dropout(dropout)

        self.reduction = nn.Conv1d(c, self.c1, kernel_size=1, stride=1, padding=0)
        self.expansion = nn.Conv1d(self.c1, c, kernel_size=1, stride=1, padding=0)
        self.lq = nn.Linear(self.c1, self.c1)
        self.lk = nn.Linear(self.c1, self.c1)
        self.lv = nn.Linear(self.c1, self.c1)

    def forward(self, x):
        batch_size = x.size(0)

        # (B, C, N) -> (B, C1, N) -> (B, N, C1)
        tokens = self.reduction(x).transpose(1, 2)

        def segments(proj):
            # (B, N, C1) -> (B, N, M, dk) -> (B, S, N_s, M, dk) -> (S, B, M, N_s, dk)
            per_head = proj(tokens).view(batch_size, -1, self.h, self.dk)
            b, n, m, d = per_head.shape
            return per_head.view(b, self.s, -1, m, d).permute(1, 0, 3, 2, 4)

        q, k, v = segments(self.lq), segments(self.lk), segments(self.lv)

        # Attention inside each segment, re-joined along the sequence axis:
        # S x (B, M, N_s, dk) -> (B, M, N, dk)
        joined = torch.cat(
            [attention(q[i], k[i], v[i], self.dk, None, self.dropout) for i in range(self.s)],
            dim=2,
        )

        # (B, M, N, dk) -> (B, N, C1) -> (B, C1, N) -> (B, C, N)
        merged = joined.transpose(1, 2).contiguous().view(batch_size, -1, self.c1)
        return self.expansion(merged.transpose(1, 2))

class Norm(nn.Module):
    """Layer-norm style normalisation over the last dimension.

    Computes ``alpha * (x - mean) / (std + eps) + bias``, where mean and the
    unbiased std are taken over dim -1, and ``alpha`` (initialised to ones) and
    ``bias`` (initialised to zeros) are learnable vectors of length ``d_model``.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        centred = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centred / spread + self.bias

# Efficient Transformer (ET) layer.
# Operates on (B, C, N) tensors whose last axis N equals d_model; the
# (B, C, H, W) -> (B, C, N) conversion happens in LTB via nn.Unfold.
class ET(nn.Module):
    """Pre-norm transformer layer: ``x + EMHA(Norm(x))``, then ``+ MLP(Norm(.))``.

    Args:
        c: channel count C of the (B, C, N) input (passed to EMHA).
        d_model: length N of the last axis; it is normalised by ``Norm`` and
            mixed by the MLP, a single ``nn.Linear(d_model, d_model)`` with no
            activation.
        heads: number of EMHA attention heads.
    """

    def __init__(self, c, d_model, heads):
        super().__init__()
        self.c = c
        self.d_model = d_model
        self.norm1 = Norm(self.d_model)
        self.norm2 = Norm(self.d_model)
        self.emha = EMHA(c, heads)
        self.mlp = nn.Linear(self.d_model, self.d_model)

    def forward(self, x):
        attended = x + self.emha(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

class LTB(nn.Module):
    """Transformer stage: unfold k x k patches, apply ET layers, fold back.

    ``nn.Unfold`` turns a (B, c, h, w) map into (B, c*k*k, h*w) tokens. After
    ``et_num`` ET layers, ``nn.Fold`` maps them back to (B, c, h, w). Fold sums
    overlapping patch contributions.

    Args:
        c: channels of the input feature map.
        h, w: spatial size of the input map (defaults: the LR patch size).
        k: patch size; must be odd so that "same" padding keeps h*w tokens.
        heads: attention heads per ET layer.
        et_num: number of stacked ET layers (independent deep copies).

    Raises:
        ValueError: if ``k`` is even.
    """

    def __init__(self, c, h=int(opt['pic_h']/opt['pic_up']), w=int(opt['pic_w']/opt['pic_up']), k=3, heads=opt['heads'], et_num=1):
        super().__init__()
        if k % 2 == 0:
            raise ValueError("LTB: patch size k must be odd, got {}".format(k))
        self.et_num = et_num
        # "Same" padding yields exactly one patch per pixel, so the token count
        # is h * w (the ET layers' d_model) and Fold can restore an (h, w) map.
        padding = k // 2
        self.unfold = nn.Unfold(kernel_size=k, dilation=1, stride=1, padding=padding)
        self.layers = get_clones(ET(c * k * k, h * w, heads), et_num)
        self.fold = nn.Fold((h, w), kernel_size=k, dilation=1, stride=1, padding=padding)

    def forward(self, x):
        y = self.unfold(x)
        for i in range(self.et_num):
            y = self.layers[i](y)
        return self.fold(y)
        
class ESRT(nn.Module):
    """ESRT super-resolution network.

    Main path: 3x3 conv (3 -> c) -> LCB -> LTB -> 3x3 conv -> PixelShuffle
    -> 3x3 conv (-> 3 channels).
    Shortcut: output of the first conv -> PixelShuffle -> 3x3 conv (-> 3).
    The two paths are summed.

    Args:
        c: feature channels; must be divisible by ``upscale ** 2`` because
            PixelShuffle(upscale) turns c channels into c / upscale**2.
        h, w: high-resolution output size. The LTB is built for the
            (h // upscale, w // upscale) input size.
        upscale: integer upscaling factor.

    Raises:
        ValueError: if ``c`` is not divisible by ``upscale ** 2``.
    """

    def __init__(self, c, h, w, upscale):
        super().__init__()
        if c % (upscale * upscale) != 0:
            raise ValueError(
                "ESRT: channels c={} must be divisible by upscale**2={}".format(
                    c, upscale * upscale))
        self.train_loops = 0
        self.lr = 2e-4
        # Channel count after PixelShuffle(upscale).
        self.c1 = c // (upscale * upscale)
        self.ltb = LTB(c, h // upscale, w // upscale)
        self.lcb = LCB(c)
        self.conv1 = nn.Conv2d(3, c, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(self.c1, 3, kernel_size=3, stride=1, padding=1)
        self.up = nn.PixelShuffle(upscale)
        self.upscale = upscale

    def forward(self, x):
        y1 = self.conv1(x)

        # main line
        y2 = self.lcb(y1)
        y2 = self.ltb(y2)
        y2 = self.conv2(y2)

        y2 = self.up(y2)
        y2 = self.conv3(y2)

        # sub line
        y1 = self.up(y1)
        y1 = self.conv4(y1)

        return y2 + y1

    def init_weights(self):
        """Placeholder: only prints a message; parameters keep PyTorch's default init."""
        print(" init weights \n")


def tmp_test():
    """Ad-hoc smoke test: push a fixed (1, 4, 4, 4) tensor through ESRT(4, 4, 4, 1).

    The same 4x4 plane is repeated on all 4 channels. The input and output
    are printed.
    """
    plane = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 4, 3], [4, 4, 4, 4]]
    a = torch.tensor([[plane] * 4], dtype=torch.float)

    rus = ESRT(4, 4, 4, 1)

    b = rus(a)
    print(a, "\n", b)



# Converts a (C, H, W) tensor back to a PIL image for display and saving.
tensor_to_pil = ToPILImage()

# Global model shared by train() and test(); run() loads its own copy from disk.
model = ESRT(opt["pic_c"],opt["pic_h"],opt["pic_w"],opt["pic_up"])

# NOTE(review): both loaders call os.listdir() at import time, so importing this
# module (even for `run` or tmp_test) fails unless these directories exist. The
# training path is machine-specific - consider making it configurable.
# Training loader: 10,000,000 batches of 5 images (effectively endless).
dataloader_train = dataloader('/home/nicholas/Documents/dataset/DIV2K_train_LR_unknown/X2', 10000000, 5)
# Test loader: one batch of 5 (LR, HR) file pairs per pass.
dataloader_test = dataloader('./dataset/images', 1, 5, test=True)

# new_training == 1 -> no checkpoint yet, train() starts from the fresh `model`;
# new_training == 0 -> train() reloads ./model/esrt.m.
if os.path.exists("./model/esrt.m"):
    print( " model exists ... \n")
    new_training = 0
else:
    print( " model not exists, will create a new one. \n")
    new_training = 1

def display_img(img1, img2):
    """Show ``img1`` (left) and ``img2`` (right) side by side, then call ``plt.show()``."""
    for position, picture in enumerate((img1, img2), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
    plt.show()

def verse_normalize(img, means, stds):
    """Undo ``transforms.Normalize(means, stds)`` in place and return ``img``.

    Accepts a single image ``(C, H, W)`` or a batch ``(B, C, H, W)``. In both
    cases the channel axis is the third-from-last dimension, which the previous
    ``img[:, i]`` indexing got wrong for 3-D input (it scaled rows instead).

    Args:
        img: float tensor to de-normalise; it is modified in place.
        means, stds: per-channel sequences used for the original normalisation.

    Returns:
        ``img`` itself, now holding ``img * std + mean`` per channel.
    """
    mean = torch.as_tensor(means, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(stds, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    img.mul_(std).add_(mean)
    return img

def train():
    """Train the global ``model`` on ``dataloader_train`` with Adam and MSE loss.

    Starts from the freshly built model when ``new_training == 1``; otherwise
    replaces it with the pickled model in ``./model/esrt.m``. The learning rate
    (stored on ``model.lr``) is halved every 2000 batches. Every 10 batches the
    model is saved to ``./model/esrt.m`` and ``test()`` is run.
    """
    global model

    if new_training == 1:
        print(" new training ...")
        model.init_weights()
    else:
        print(" continue training ...")
        # NOTE(review): this unpickles a whole module; PyTorch versions where
        # torch.load defaults to weights_only=True need weights_only=False here.
        model = torch.load("./model/esrt.m")

    loss = nn.MSELoss()
    plt.ion()

    # Halve the learning rate every `lr_decay_every` batches.
    lr_decay_every = 2000
    lr_index = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=model.lr, betas=(0.9, 0.999))

    for X, Y in dataloader_train:
        lr_index += 1
        if lr_index % lr_decay_every == 0:
            model.lr /= 2
            print( " model lr = ", model.lr)
            # Update the rate in place so Adam keeps its moment estimates.
            for group in optimizer.param_groups:
                group["lr"] = model.lr

        model.train()

        Y_ = model(X)
        optimizer.zero_grad()
        cost = loss(Y_, Y)
        cost.backward()
        optimizer.step()

        print( "batch:{}, loss  is {:.4f} ".format(model.train_loops, cost.item()) )

        model.train_loops += 1
        if model.train_loops % 10 == 0:
            # torch.save does not create missing directories.
            os.makedirs("./model", exist_ok=True)
            torch.save(model, "./model/esrt.m")
            test()

    plt.ioff()

def test():
    """Run the global ``model`` on ``dataloader_test``, plot results and print PSNR.

    For each batch only the first sample is used. The prediction is shown on
    the left, the low-resolution input on the right, and the PSNR between the
    prediction and the HR target is printed. Leaves the model in eval mode.
    """
    global model

    model.eval()
    with torch.no_grad():
        for X, Y in dataloader_test:
            _, _, dst_h, dst_w = Y.shape

            Y_ = model(X)

            # Slice with [:1] (not [0]) so verse_normalize sees (1, C, H, W) and
            # de-normalises the channel axis; clamp so ToPILImage's byte cast
            # cannot wrap out-of-range values.
            OX = tensor_to_pil(verse_normalize(X[:1], means, stds)[0].clamp(0, 1))
            plt.subplot(1, 2, 2)
            plt.imshow(OX)

            OY_ = tensor_to_pil(verse_normalize(Y_[:1], means, stds)[0].clamp(0, 1))
            plt.subplot(1, 2, 1)
            plt.imshow(OY_)

            OY = tensor_to_pil(verse_normalize(Y[:1], means, stds)[0].clamp(0, 1))

            # No Normalize: both images are already de-normalised, so ToTensor
            # yields [0, 1] values as psnr expects.
            transform_last = transforms.Compose([
                transforms.CenterCrop((dst_h, dst_w)),
                transforms.ToTensor(),
                ])
            print(" Psnr = ", psnr(transform_last(OY_), transform_last(OY)) )

            plt.show()
            plt.pause(0.1)

    #plt.ioff()

def run(path=None):
    """Super-resolve every image in ``path`` tile by tile, then save and show it.

    Loads the pickled model from ``./model/esrt.m``. Each image is processed in
    overlapping low-resolution tiles of ``(pic_h // pic_up, pic_w // pic_up)``
    pixels. Neighbouring tiles overlap by ``2 * cat_padding`` pixels, and only
    the centre of each super-resolved tile (its ``cat_padding`` border removed)
    is written to the output, which hides seams at tile edges. The result is
    written to ``./ESRT.png`` (overwritten for every image) and displayed next
    to the input.

    Args:
        path: directory of input images; ``None`` means ``./images``.
    """
    model = torch.load("./model/esrt.m")
    model.eval()
    if path is None:
        path = "./images"

    up = opt['pic_up']
    cat_pad = opt['cat_padding']

    # size of the low-resolution tiles fed to the model
    lh = opt['pic_h'] // up
    lw = opt['pic_w'] // up

    # stride between neighbouring tiles
    stride_h = lh - cat_pad * 2
    stride_w = lw - cat_pad * 2

    # centre sub-block of each model output that is kept
    y_top = cat_pad * up
    y_left = cat_pad * up
    y_down = (lh - cat_pad) * up
    y_right = (lw - cat_pad) * up

    for img_name in os.listdir(path):
        # Always work in 3-channel RGB; the with-block closes the file handle.
        with Image.open(os.path.join(path, img_name)) as raw:
            org_img = raw.convert("RGB")

        org_w, org_h = org_img.size
        dst_w, dst_h = org_w * up, org_h * up

        # Pad (CenterCrop to a larger size pads with zeros) so that the image
        # plus a cat_pad border on each side is a whole number of strides.
        pad_h = org_h + cat_pad * 2
        pad_w = org_w + cat_pad * 2
        pad_h_off = (pad_h - 2 * cat_pad) % stride_h
        pad_w_off = (pad_w - 2 * cat_pad) % stride_w
        if pad_h_off != 0:
            pad_h += stride_h - pad_h_off
        if pad_w_off != 0:
            pad_w += stride_w - pad_w_off
        img = transforms.CenterCrop((pad_h, pad_w))(org_img)
        w, h = img.size

        # size of the SR canvas
        sr_h = h * up
        sr_w = w * up

        transform_mid = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
            ])
        transform_last = transforms.Compose([
            transforms.CenterCrop((dst_h, dst_w)),
            transforms.ToTensor(),
            # no Normalize: Y is de-normalised before this is applied
            ])

        Y = torch.zeros((1, 3, sr_h, sr_w))
        img_ = transform_mid(img)

        with torch.no_grad():
            # Tiles start every stride pixels while they still fit in the image.
            for top in range(0, h - lh + 1, stride_h):
                down = top + lh
                for left in range(0, w - lw + 1, stride_w):
                    right = left + lw
                    Y_ = model(img_[:, top:down, left:right].unsqueeze(0))

                    cent_top = (top + cat_pad) * up
                    cent_left = (left + cat_pad) * up
                    cent_down = (down - cat_pad) * up
                    cent_right = (right - cat_pad) * up
                    Y[0, :, cent_top:cent_down, cent_left:cent_right] = \
                        Y_[0, :, y_top:y_down, y_left:y_right]

        # Clamp so ToPILImage's byte cast cannot wrap out-of-range values.
        Y = verse_normalize(Y, means, stds)[0].clamp(0, 1)
        Y = transform_last(tensor_to_pil(Y))
        # NOTE(review): org_img is the low-resolution input, so CenterCrop to the
        # SR size zero-pads it; this PSNR is not against a ground-truth HR image.
        print(" Psnr = ", psnr(Y, transform_last(org_img)) )
        Y = tensor_to_pil(Y)
        Y.save("./ESRT.png")
        display_img( Y,  org_img)

        plt.pause(1)
        plt.clf()


"""

For training run this command:
python esrt.py train

For running run this command:
python esrt.py run
or
python esrt.py run YOUR_IMAGES_DIRECTORY

"""
if __name__ == '__main__':
    args = sys.argv[1:]
    print( args, len(args))
    # `and` short-circuits: with no arguments, args[0] is never evaluated and
    # control falls through to tmp_test() instead of raising IndexError.
    if len(args) == 1 and args[0] == "train":
        train()
    elif len(args) == 1 and args[0] == "run":
        # Pass None explicitly: run() substitutes its default image directory.
        run(None)
    elif len(args) == 2 and args[0] == "run":
        run(args[1])
    else:
        tmp_test()



标签: python machine_learning DL

发表评论:

Powered by anycle 湘ICP备15001973号-1