簡(jiǎn)單使用PyTorch搭建GAN模型
以往人們普遍認(rèn)為生成圖像是不可能完成的任務(wù),因?yàn)榘凑諅鹘y(tǒng)的機(jī)器學(xué)習(xí)思路,我們根本沒(méi)有真值(ground truth)可以拿來(lái)檢驗(yàn)生成的圖像是否合格。
2014年,Goodfellow等人則提出生成 對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN) ,能夠讓我們完全依靠機(jī)器學(xué)習(xí)來(lái)生成極為逼真的圖片。GAN的橫空出世使得整個(gè)人工智能行業(yè)都為之震動(dòng),計(jì)算機(jī)視覺(jué)和圖像生成領(lǐng)域發(fā)生了巨變。
本文將帶大家了解 GAN的工作原理 ,并介紹如何 通過(guò)PyTorch簡(jiǎn)單上手GAN 。
GAN的原理
按照傳統(tǒng)的方法,模型的預(yù)測(cè)結(jié)果可以直接與已有的真值進(jìn)行比較。然而,我們卻很難定義和衡量到底怎樣才算作是“正確的”生成圖像。
Goodfellow等人則提出了一個(gè)有趣的解決辦法:我們可以先訓(xùn)練好一個(gè)分類(lèi)工具,來(lái)自動(dòng)區(qū)分生成圖像和真實(shí)圖像。這樣一來(lái),我們就可以用這個(gè)分類(lèi)工具來(lái)訓(xùn)練一個(gè)生成網(wǎng)絡(luò),直到它能夠輸出完全以假亂真的圖像,連分類(lèi)工具自己都沒(méi)有辦法評(píng)判真假。
按照這一思路,我們便有了GAN:也就是一個(gè) 生成器(generator) 和一個(gè) 判別器(discriminator) 。生成器負(fù)責(zé)根據(jù)給定的數(shù)據(jù)集生成圖像,判別器則負(fù)責(zé)區(qū)分圖像是真是假。GAN的運(yùn)作流程如上圖所示。
損失函數(shù)
在GAN的運(yùn)作流程中,我們可以發(fā)現(xiàn)一個(gè)明顯的矛盾:同時(shí)優(yōu)化生成器和判別器是很困難的??梢韵胂?,這兩個(gè)模型有著完全相反的目標(biāo):生成器想要盡可能偽造出真實(shí)的東西,而判別器則必須要識(shí)破生成器生成的圖像。
為了說(shuō)明這一點(diǎn),我們?cè)O(shè)D(x)為判別器的輸出,即x是真實(shí)圖像的概率,并設(shè)G(z)為生成器的輸出。判別器類(lèi)似于一種二進(jìn)制的分類(lèi)器,所以其目標(biāo)是使該函數(shù)的結(jié)果最大化:這一函數(shù)本質(zhì)上是非負(fù)的二元交叉熵?fù)p失函數(shù)。另一方面,生成器的目標(biāo)是最小化判別器做出正確判斷的機(jī)率,因此它的目標(biāo)是使上述函數(shù)的結(jié)果最小化。
因此,最終的損失函數(shù)將會(huì)是兩個(gè)分類(lèi)器之間的極小極大博弈,表示如下:理論上來(lái)說(shuō),博弈的最終結(jié)果將是讓判別器判斷成功的概率收斂到0.5。然而在實(shí)踐中,極大極小博弈通常會(huì)導(dǎo)致網(wǎng)絡(luò)不收斂,因此仔細(xì)調(diào)整模型訓(xùn)練的參數(shù)非常重要。
在訓(xùn)練GAN時(shí),我們尤其要注意學(xué)習(xí)率等超參數(shù),學(xué)習(xí)率比較小時(shí)能讓GAN在輸入噪音較多的情況下也能有較為統(tǒng)一的輸出。
計(jì)算環(huán)境
庫(kù)
本文將指導(dǎo)大家通過(guò)PyTorch搭建整個(gè)程序(包括torchvision)。同時(shí),我們將會(huì)使用Matplotlib來(lái)讓GAN的生成結(jié)果可視化。以下代碼能夠?qū)肷鲜鏊袔?kù):
- """
- Import necessary libraries to create a generative adversarial network
- The code is mainly developed using the PyTorch library
- """
- import time
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from torchvision import datasets
- from torchvision.transforms import transforms
- from model import discriminator, generator
- import numpy as np
- import matplotlib.pyplot as plt
數(shù)據(jù)集
數(shù)據(jù)集對(duì)于訓(xùn)練GAN來(lái)說(shuō)非常重要,尤其考慮到我們?cè)贕AN中處理的通常是非結(jié)構(gòu)化數(shù)據(jù)(一般是圖片、視頻等),任意一class都可以有數(shù)據(jù)的分布。這種數(shù)據(jù)分布恰恰是GAN生成輸出的基礎(chǔ)。
為了更好地演示GAN的搭建流程,本文將帶大家使用最簡(jiǎn)單的MNIST數(shù)據(jù)集,其中含有6萬(wàn)張手寫(xiě)阿拉伯?dāng)?shù)字的圖片。
像 MNIST 這樣高質(zhì)量的非結(jié)構(gòu)化數(shù)據(jù)集都可以在 格物鈦 的 公開(kāi)數(shù)據(jù)集 網(wǎng)站上找到。事實(shí)上,格物鈦Open Datasets平臺(tái)涵蓋了很多優(yōu)質(zhì)的公開(kāi)數(shù)據(jù)集,同時(shí)也可以實(shí)現(xiàn) 數(shù)據(jù)集托管及一站式搜索的功能 ,這對(duì)AI開(kāi)發(fā)者來(lái)說(shuō),是相當(dāng)實(shí)用的社區(qū)平臺(tái)。
硬件需求
一般來(lái)說(shuō),雖然可以使用CPU來(lái)訓(xùn)練神經(jīng)網(wǎng)絡(luò),但最佳選擇其實(shí)是GPU,因?yàn)檫@樣可以大幅提升訓(xùn)練速度。我們可以用下面的代碼來(lái)測(cè)試自己的機(jī)器能否用GPU來(lái)訓(xùn)練:
- """
- Determine if any GPUs are available
- """
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實(shí)現(xiàn)
網(wǎng)絡(luò)結(jié)構(gòu)
由于數(shù)字是非常簡(jiǎn)單的信息,我們可以將判別器和生成器這兩層結(jié)構(gòu)都組建成全連接層(fully connected layers)。
我們可以用以下代碼在PyTorch中搭建判別器和生成器:
- """
- Network Architectures
- The following are the discriminator and generator architectures
- """
- class discriminator(nn.Module):
- def __init__(self):
- super(discriminator, self).__init__()
- self.fc1 = nn.Linear(784, 512)
- self.fc2 = nn.Linear(512, 1)
- self.activation = nn.LeakyReLU(0.1)
- def forward(self, x):
- x = x.view(-1, 784)
- x = self.activation(self.fc1(x))
- x = self.fc2(x)
- return nn.Sigmoid()(x)
- class generator(nn.Module):
- def __init__(self):
- super(generator, self).__init__()
- self.fc1 = nn.Linear(128, 1024)
- self.fc2 = nn.Linear(1024, 2048)
- self.fc3 = nn.Linear(2048, 784)
- self.activation = nn.ReLU()
- def forward(self, x):
- x = self.activation(self.fc1(x))
- x = self.activation(self.fc2(x))
- x = self.fc3(x)
- x = x.view(-1, 1, 28, 28)
- return nn.Tanh()(x)
訓(xùn)練
在訓(xùn)練GAN的時(shí)候,我們需要一邊優(yōu)化判別器,一邊改進(jìn)生成器,因此每次迭代我們都需要同時(shí)優(yōu)化兩個(gè)互相矛盾的損失函數(shù)。
對(duì)于生成器,我們將輸入一些隨機(jī)噪音,讓生成器來(lái)根據(jù)噪音的微小改變輸出的圖像:
- """
- Network training procedure
- Every step both the loss for disciminator and generator is updated
- Discriminator aims to classify reals and fakes
- Generator aims to generate images as realistic as possible
- """
- for epoch in range(epochs):
- for idx, (imgs, _) in enumerate(train_loader):
- idx += 1
- # Training the discriminator
- # Real inputs are actual images of the MNIST dataset
- # Fake inputs are from the generator
- # Real inputs should be classified as 1 and fake as 0
- real_inputs = imgs.to(device)
- real_outputs = D(real_inputs)
- real_label = torch.ones(real_inputs.shape[0], 1).to(device)
- noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
- outputs = torch.cat((real_outputs, fake_outputs), 0)
- targets = torch.cat((real_label, fake_label), 0)
- D_loss = loss(outputs, targets)
- D_optimizer.zero_grad()
- D_loss.backward()
- D_optimizer.step()
- # Training the generator
- # For generator, goal is to make the discriminator believe everything is 1
- noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
- G_loss = loss(fake_outputs, fake_targets)
- G_optimizer.zero_grad()
- G_loss.backward()
- G_optimizer.step()
- if idx % 100 == 0 or idx == len(train_loader):
- print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
- if (epoch+1) % 10 == 0:
- torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
- print('Model saved.')
結(jié)果
經(jīng)過(guò)100個(gè)訓(xùn)練時(shí)期之后,我們就可以對(duì)數(shù)據(jù)集進(jìn)行可視化處理,直接看到模型從隨機(jī)噪音生成的數(shù)字:
我們可以看到,生成的結(jié)果和真實(shí)的數(shù)據(jù)非常相像??紤]到我們?cè)谶@里只是搭建了一個(gè)非常簡(jiǎn)單的模型,實(shí)際的應(yīng)用效果會(huì)有非常大的上升空間。
不僅是有樣學(xué)樣
GAN和以往機(jī)器視覺(jué)專(zhuān)家提出的想法都不一樣,而利用GAN進(jìn)行的具體場(chǎng)景應(yīng)用更是讓許多人贊嘆深度網(wǎng)絡(luò)的無(wú)限潛力。下面我們來(lái)看一下兩個(gè)最為出名的GAN延申應(yīng)用。
CycleGAN
朱俊彥等人2017年發(fā)表的CycleGAN能夠在沒(méi)有配對(duì)圖片的情況下將一張圖片從X域直接轉(zhuǎn)換到Y(jié)域,比如把馬變成斑馬、將熱夏變成隆冬、把莫奈的畫(huà)變成梵高的畫(huà)等等。這些看似天方夜譚的轉(zhuǎn)換CycleGAN都能輕松做到,并且結(jié)果非常準(zhǔn)確。
GauGAN
英偉達(dá)則通過(guò)GAN讓人們能夠只需要寥寥數(shù)筆勾勒出自己的想法,便能得到一張極為逼真的真實(shí)場(chǎng)景圖片。雖然這種應(yīng)用需要的計(jì)算成本極為高昂,但是GauGAN憑借它的轉(zhuǎn)換能力探索出了前所未有的研究和應(yīng)用領(lǐng)域。
結(jié)語(yǔ)
相信看到這里,你已經(jīng)知道了GAN的大致工作原理,并且能夠自己動(dòng)手簡(jiǎn)單搭建一個(gè)GAN了。