对抗生成网络:指马为斑马

发布于: 雪球转发:0回复:0喜欢:0

生成对抗网络GAN是去年以来比较火的一个技术,它通过一个生成网络来形成新的内容,再通过一个判别网络来判断生成的内容是否是想要的内容。

一个简单的实现如下(非原创):

# ResNetGenerator
import torch
import torch.nn as nn   
class ResNetBlock(nn.Module):
      def __init__(self, dim):
         super(ResNetBlock, self).__init__()
         self.conv_block = self.build_conv_block(dim)
      def build_conv_block(self, dim):
          conv_block = []
          conv_block += [nn.ReflectionPad2d(1)]
          conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                        nn.InstanceNorm2d(dim),
                        nn.ReLU(True)]
          conv_block += [nn.ReflectionPad2d(1)]
          conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                        nn.InstanceNorm2d(dim)]
          return nn.Sequential(*conv_block)
      def forward(self, x):
         out = x + self.conv_block(x)
         return out
   class ResNetGenerator(nn.Module):
      def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
         assert(n_blocks >= 0)
         super(ResNetGenerator, self).__init__()
         self.input_nc = input_nc
         self.output_nc = output_nc
         self.ngf = ngf
         model = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                  nn.InstanceNorm2d(ngf),
                  nn.ReLU(True)]
          n_downsampling = 2
         for i in range(n_downsampling):
             mult = 2**i
             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,                                  stride=2, padding=1, bias=True),
                       nn.InstanceNorm2d(ngf * mult * 2),
                       nn.ReLU(True)]
         mult = 2**n_downsampling
         for i in range(n_blocks):
             model += [ResNetBlock(ngf * mult)]
          for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1,
                                          bias=True),
                       nn.InstanceNorm2d(int(ngf * mult / 2)),
                       nn.ReLU(True)]
          model += [nn.ReflectionPad2d(3)]
         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
         model += [nn.Tanh()]
          self.model = nn.Sequential(*model)
      def forward(self, input):
         return self.model(input)

我们可以用它来实现把马变成斑马,首先创建一个实例

netG = ResNetGenerator()

然后下载一个训练好的模型参数给我们的netG

!git clone 网页链接
model_path = 'dlwpt-code/data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

将模型调整为评估模式

netG.eval()

随便找一张马的图片,读取图片

from PIL import Image
from torchvision import transforms
img = Image.open("horse.jpg")img

对图片进行一些处理

preprocess = transforms.Compose([transforms.Resize(256),
                                                          transforms.ToTensor()])
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
batch_out = netG(batch_t)

指马为斑马

out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
out_img