阅读 90

Pytorch袖珍手册之八

pytorch pocket reference

原书下载地址:
我用阿里云盘分享了「OReilly.PyTorch.Pocket.R...odels.149209000X.pdf」,你可以不限速下载?
复制这段内容打开「阿里云盘」App 即可获取
链接:https://www.aliyundrive.com/s/NZvnGbTYr6C

第四章 基于已有网络设计进行神经网络应用开发

这一章主要通过三个例子来表现Pytorch在神经网络开发应用的便捷性及高效性。

  • 基于迁移学习的图片分类
  • 自然语言处理里的情感分析
  • GAN,生成图片

GAN,基于Fashion MNIST数据生成图片

深度学习另一个应用场景就是生成学习(generative learning),主要通过模型来生成数据,如图片,音乐,文本和时间系列数据等。

在本章节的例子中,我们通过构建一个GAN模型生成如Fashion-MNIST里的图片数据。

  • 数据预处理 Data Processing
    跟之前的两个例子有些不同,GAN模型主要通过学习训练数据来生成一些跟训练数据相似的数据,以达到“以假弄真”的作用。
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torchvision.utils import make_grid


CODING_SIZE = 100
BATCH_SIZE = 32
IMAGE_SIZE = 64

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 定义transforms
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor()
    ]
)

dataset = datasets.FashionMNIST(
    './data',
    train=True,
    download=True,
    transform=transform
)

dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0 # 在windows下无法用多进程
    )


# 数据可视化,查看数据情况
data_batch, labels_batch = next(iter(dataloader))
# batch_size 32, 4*8
grid_img = make_grid(data_batch, nrow=8)
plt.figure(figsize=(10,10))
plt.imshow(grid_img.permute(1, 2, 0))
Fashion MNIST
  • 模型构建 生成器&判别器
# 生成器模型构建
class Generator(nn.Module):
    def __init__(self, coding_sz):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # 反卷积操作
            nn.ConvTranspose2d(coding_sz, 1024, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 1, 4, 2, 1),
            nn.Tanh()       
        )
        
    def forwar(self, input):
        return self.net(input)
    
netG = Generator(CODING_SIZE).to(device)


# 判别器模型构建
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1024, 4, 2, 1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, 4, 1, 0),
            nn.Sigmoid()
        )
        
    def forward(self, input):
        return self.net(input)
    
netD = Discriminator().to(device)

# DCGAN paper found that it helps to initialize the weights 
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        

netG.apply(weights_init)
netD.apply(weights_init)
“”“
Generator(
  (net): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)
============================================================================================
Discriminator(
  (net): Sequential(
    (0): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)
”“”
  • 训练模型
# 训练模型
"""
In each epoch, we will first train the discriminator with a real batch of data, 
then use the generator to create a fake batch, and then train the discriminator with the generated fake batch of data. 
Lastly, we will train the generator NN to produce better fakes.
"""
# 定义损失函数及优化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))

# 定义真假数据标签值
real_labels = torch.full((BATCH_SIZE,), 1., dtype=torch.float, device=device)
fake_labels = torch.full((BATCH_SIZE,), 0., dtype=torch.float, device=device)


G_losses = []
D_losses = []
D_real = []
D_fake = []
N_EPOCHS = 5

z = torch.randn((BATCH_SIZE, 100)).view(-1, 100, 1, 1).to(device)
test_out_images = []

for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch}')
    for i, batch in enumerate(dataloader):
        if (i%200==0):
            print(f'batch: {i} of {len(dataloader)}')
    
        # 训练判别器,基于真数据
        netD.zero_grad()
        real_images = batch[0].to(device) *2. - 1.
        output = netD(real_images).view(-1) 
        errD_real = criterion(output, real_labels)
        D_x = output.mean().item()

        # 训练判别器,基于假数据
        noise = torch.randn((BATCH_SIZE, CODING_SIZE))
        noise = noise.view(-1,100,1,1).to(device)
        fake_images = netG(noise)
        output = netD(fake_images).view(-1) 
        errD_fake = criterion(output, fake_labels)
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        errD.backward(retain_graph=True) 
        optimizerD.step()

        # 训练生成器,产生更逼真数据
        netG.zero_grad()
        output = netD(fake_images).view(-1) 
        errG = criterion(output, real_labels)
        errG.backward() 
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # 保存中间变量值,用于后续画图
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        D_real.append(D_x)
        D_fake.append(D_G_z2)

    test_images = netG(z).to('cpu').detach() 
    test_out_images.append(test_images)
        
         
grid_img = make_grid((test_out_images[0]+1.)/2., nrow=8)
plt.imshow(grid_img.permute(1, 2, 0)) 
image.png
  • 训练过程中,生成判别模型损失值情况
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
image.png
  • Discriminator Results
plt.figure(figsize=(10,5))
plt.title("Discriminator Results")
plt.plot(D_real,label="D(real)")
plt.plot(D_fake,label="D(fake)")
plt.xlabel("iterations")
plt.ylabel("Percentage Real")
plt.legend()
plt.show()
image.png
  • 模型生成图片数据效果
grid_img = make_grid((test_out_images[4]+1.)/2., nrow=8)
plt.imshow(grid_img.permute(1, 2, 0))
image.png

作者:深思海数_willschang

原文链接:https://www.jianshu.com/p/d968d81bcca8

文章分类
后端
文章标签
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐