輕松構建 PyTorch 生成對抗網絡(GAN)
展現在您眼前的這幅圖像中的人物并非自真實存在,其實她是由一個機器學習模型創造出來的虛擬人物。圖片取自 維基百科的 GAN 條目,畫面細節豐富、色彩逼真,讓人印象深刻。
生成對抗網絡(GAN)是一種生成式機器學習模型,它被廣泛應用于廣告、游戲、娛樂、媒體、制藥等行業,可以用來創造虛構的人物、場景,模擬人臉老化,圖像風格變換,以及產生化學分子式等等。下面兩張圖片,分別展示了圖片到圖片轉換的效果,以及基于語義布局合成景物的效果。


本文將引領讀者,從工程實踐角度出發,借助 AWS 機器學習相關云計算服務,基于 PyTorch 機器學習框架,構建第一個生成對抗網絡,開啟全新的、有趣的機器學習和人工智能體驗。
還等什么,讓我們馬上開始吧!
主要內容
- 課題及方案概覽
- 模型的開發環境
- 生成對抗網絡模型
- 模型的訓練和驗證
- 結論與總結
課題及方案概覽
下面顯示的兩組手寫體數字圖片,您是否能從中夠辨認出由計算機生成的『手寫』字體是其中哪一組?


本文的課題是用機器學習方法『模仿手寫字體』,為了完成這個課題,您將親手體驗生成對抗網絡的設計和實現?!耗7率謱懽煮w』與人像生成的基本原理和工程流程基本是一致的,雖然它們的復雜性和精度要求有一定差距,但是通過解決『模仿手寫字體』問題,可以為生成對抗網絡的原理和工程實踐打下基礎,進而可以逐步嘗試和探索更加復雜先進的網絡架構和應用場景。
《生成對抗網絡》(GAN)由 Ian Goodfellow 等人在 2014年提出,它是一種深度神經網絡架構,由一個生成網絡和一個判別網絡組成。生成網絡產生『假』數據,并試圖欺騙判別網絡;判別網絡對生成數據進行真偽鑒別,試圖正確識別所有『假』數據。在訓練迭代的過程中,兩個網絡持續地進化和對抗,直到達到平衡狀態(參考:納什均衡),判別網絡無法再識別『假』數據,訓練結束。
2016年,Alec Radford 等發表的論文 《深度卷積生成對抗網絡》(DCGAN)中,開創性地將卷積神經網絡應用到生成對抗網絡的模型算法設計當中,替代了全鏈接層,提高了圖片場景里訓練的穩定性。
Amazon SageMaker 是 AWS 完全托管的機器學習服務,數據處理和機器學習訓練工作可以通過 Amazon SageMaker 快速、輕松地完成,訓練好的模型可以直接部署到全托管的生產環境中。Amazon SageMaker 提供了托管的 Jupyter Notebook 實例,通過 SageMaker SDK 與 AWS 的多種云服務集成,方便您訪問數據源,進行探索和分析。SageMaker SDK 是一套開放源代碼的 Amazon SageMaker 的開發包,可以協助您很好的使用 Amazon SageMaker 提供的托管容器鏡像,以及 AWS 的其他云服務,如計算和存儲資源。

如上圖所示,訓練用數據將來自 Amazon S3 的存儲桶;訓練用的框架和托管算法以容器鏡像的形式提供服務,在訓練時與代碼結合;模型代碼運行在 Amazon SageMaker 托管的計算實例中,在訓練時與數據結合;訓練輸出物將進入 Amazon S3 專門的存儲桶里。后面的講解中,我們會了解到如何通過 SageMaker SDK 使用這些資源。
我們將用到 Amazon SageMaker、Amazon S3 、Amazon EC2 等 AWS 服務,會產生一定的云資源使用費用。
模型的開發環境
創建Notebook實例
請打開 Amazon SageMaker 的儀表板(點擊打開 北京區域 | 寧夏區域 ),請點擊Notebook instances 按鈕進入筆記本實例列表。

如果您是第一次使用Amazon SageMaker,您的 Notebook instances 列表將顯示為空列表,此時您需點擊 Create notebook instance 按鈕來創建全新 Jupyter Notebook 實例。

進入 Create notebook instance 頁面后,請在 Notebook instance name 字段里輸入實例名字,本文將使用 MySageMakerInstance 作為實例名,您可以選用您認為合適的名字。本文將使用默認的實例類型,因此 Notebook instance type 選項將保持為 *ml.t2.medium*。如果您是第一次使用Amazon SageMaker,您需要創建一個 IAM role,以便筆記本實例能夠訪問 Amazon S3 服務。請在 IAM role 選項點擊為 Create a new role。Amazon SageMaker 將創建一個具有必要權限的角色,并將這個角色分配給正在創建的實例。另外,根據您的實際情況,您也可以選擇一個已經存在的角色。

在 Create an IAM role 彈出窗口里,您可以選擇 *Any S3 bucket*,這樣筆記本實例將能夠訪問您賬戶里的所有桶。另外,根據您的需要,您還可以選擇 Specific S3 buckets并輸入桶名。點擊 Create role 按鈕,這個新角色將被創建。

此時,可以看到 Amazon SageMaker 為您創建了一個名字類似 *
AmazonSageMaker-ExecutionRole-**** 的角色。對于其他字段,您可以使用默認值,請點擊 Create notebook instance 按鈕,創建實例。

回到 Notebook instances 頁面,您會看到 MySageMakerInstance 筆記本實例顯示為 Pending 狀態,這個將持續2分鐘左右,直到轉為 InService 狀態。

編寫第一行代碼
點擊 Open JupyterLab 鏈接,在新的頁面里,您將看到熟悉的 Jupyter Notebook 加載界面。本文默認以 JupyterLab 筆記本作為工程環境,根據您的需要,可以選擇使用傳統的 Jupyter 筆記本。

您將通過點擊 conda_pytorch_p36, 筆記本圖標來創建一個叫做 Untitled.ipynb 的筆記本,您可以稍后更改它的名字。另外,您也可以通過 File > New > Notebook 菜單路徑,并選擇 conda_pytorch_p36 作為 Kernel 來創建這個筆記本。

在新建的 Untitled.ipynb 筆記本里,我們將輸入第一行指令如下,
- import torch
- print(f"Hello PyTorch {torch.__version__}")
源代碼下載
請在筆記本中輸入如下指令,下載代碼到實例本地文件系統。
下載完成后,您可以通過 File browser 瀏覽源代碼結構。

本文涉及到的代碼和筆記本均通過 Amazon SageMaker 托管的 Python 3.6、PyTorch 1.4 和 JupyterLab 驗證。本文涉及到的代碼和筆記本可以通過 這里獲取。
生成對抗網絡模型
算法原理
DCGAN模型的生成網絡包含10層,它使用跨步轉置卷積層來提高張量的分辨率,輸入形狀為 (batchsize, 100) ,輸出形狀為 (batchsize, 64, 64, 3)。換句話說,生成網絡接受噪聲向量,然后經過不斷變換,直到生成最終的圖像。
判別網絡也包含10層,它接收 (64, 64, 3) 格式的圖片,使用2D卷積層進行下采樣,最后傳遞給全鏈接層進行分類,分類結果是 1 或 0,即真與假。

DCGAN 模型的訓練過程大致可以分為三個子過程。

首先, Generator 網絡以一個隨機數作為輸入,生成一張『假』圖片;接下來,分別用『真』圖片和『假』圖片訓練 Discriminator 網絡,更新參數;最后,更新 Generator 網絡參數。
代碼分析
項目目錄 byos-pytorch-gan 的文件結構如下,
文件 model.py 中包含 3 個類,分別是 生成網絡 Generator 和 判別網絡 Discriminator。
- class Generator(nn.Module):
- ...
- class Discriminator(nn.Module):
- ...
- class DCGAN(object):
- """
- A wrapper class for Generator and Discriminator,
- 'train_step' method is for single batch training.
- """
- ...
文件 train.py 用于 Generator 和 Discriminator 兩個神經網絡的訓練,主要包含以下幾個方法,
- def parse_args():
- ...
- def get_datasets(dataset_name, ...):
- ...
- def train(dataloader, hps, ...):
- ...
模型的調試
開發和調試階段,可以從 Linux 命令行直接運行 train.py 腳本。超參數、輸入數據通道、模型和其他訓練產出物存放目錄都可以通過命令行參數指定。
- python dcgan/train.py --dataset qmnist \
- --model-dir '/home/myhome/byom-pytorch-gan/model' \
- --output-dir '/home/myhome/byom-pytorch-gan/tmp' \
- --data-dir '/home/myhome/byom-pytorch-gan/data' \
- --hps '{"beta1":0.5,"dataset":"qmnist","epochs":15,"learning-rate":0.0002,"log-interval":64,"nc":1,"nz":100,"sample-interval":100}'
這樣的訓練腳本參數設計,既提供了很好的調試方法,又是與 SageMaker Container 集成的規約和必要條件,很好的兼顧了模型開發的自由度和訓練環境的可移植性。
模型的訓練和驗證
請查找并打開名為 dcgan.ipynb 的筆記本文件,訓練過程將由這個筆記本介紹并執行,本節內容代碼部分從略,請以筆記本代碼為準。
互聯網環境里有很多公開的數據集,對于機器學習的工程和科研很有幫助,比如算法學習和效果評價。我們將使用 QMNIST 這個手寫字體數據集訓練模型,最終生成逼真的『手寫』字體效果圖樣。
數據準備
PyTorch 框架的 torchvision.datasets 包提供了QMNIST 數據集,您可以通過如下指令下載 QMNIST 數據集到本地備用。
- from torchvision import datasets
- dataroot = './data'
- trainset = datasets.QMNIST(root=dataroot, train=True, download=True)
- testset = datasets.QMNIST(root=dataroot, train=False, download=True)
Amazon SageMaker 為您創建了一個默認的 Amazon S3 桶,用來存取機器學習工作流程中可能需要的各種文件和數據。 我們可以通過 SageMaker SDK 中 sagemaker.session.Session 類的 default_bucket 方法獲得這個桶的名字。
- from sagemaker.session import Session
- sess = Session()
- # S3 bucket for saving code and model artifacts.
- # Feel free to specify a different bucket here if you wish.
- bucket = sess.default_bucket()
SageMaker SDK 提供了操作 Amazon S3 服務的包和類,其中 S3Downloader 類用于訪問或下載 S3 里的對象,而 S3Uploader 則用于將本地文件上傳至 S3。您將已經下載的數據上傳至 Amazon S3,供模型訓練使用。模型訓練過程不要從互聯網下載數據,避免通過互聯網獲取訓練數據的產生的網絡延遲,同時也規避了因直接訪問互聯網對模型訓練可能產生的安全風險。
- from sagemaker.s3 import S3Uploader as s3up
- s3_data_location = s3up.upload(f"{dataroot}/QMNIST", f"s3://{bucket}/data/qmnist")
訓練執行
通過
sagemaker.getexecutionrole() 方法,當前筆記本可以得到預先分配給筆記本實例的角色,這個角色將被用來獲取訓練用的資源,比如下載訓練用框架鏡像、分配 Amazon EC2 計算資源等等。
訓練模型用的超參數可以在筆記本里定義,實現與算法代碼的分離,在創建訓練任務時傳入超參數,與訓練任務動態結合。
- hps = {
- "learning-rate": 0.0002,
- "epochs": 15,
- "dataset": "qmnist",
- "beta1": 0.5,
- "sample-interval": 200,
- "log-interval": 64
- }
sagemaker.pytorch 包里的 PyTorch 類是基于 PyTorch 框架的模型擬合器,可以用來創建、執行訓練任務,還可以對訓練完的模型進行部署。參數列表中, train_instance_type 用來指定CPU或者GPU實例類型,訓練腳本和包括模型代碼所在的目錄通過 source_dir 指定,訓練腳本文件名必須通過 entry_point 明確定義。這些參數將和其余參數一起被傳遞給訓練任務,他們決定了訓練任務的運行環境和模型訓練時參數。
- from sagemaker.pytorch import PyTorch
- estimator = PyTorch(role=role,
- entry_point='train.py',
- source_dir='dcgan',
- output_path=s3_model_artifacts_location,
- code_location=s3_custom_code_upload_location,
- train_instance_count=1,
- train_instance_type='ml.c5.xlarge',
- train_use_spot_instances=True,
- train_max_wait=86400,
- framework_version='1.4.0',
- py_version='py3',
- hyperparameters=hps)
請特別注意 train_use_spot_instances 參數,True 值代表您希望優先使用 SPOT 實例。由于機器學習訓練工作通常需要大量計算資源長時間運行,善用 SPOT 可以幫助您實現有效的成本控制,SPOT 實例價格可能是按需實例價格的 20% 到 60%,依據選擇實例類型、區域、時間不同實際價格有所不同。
您已經創建了 PyTorch 對象,下面可以用它來擬合預先存在 Amazon S3 上的數據了。下面的指令將執行訓練任務,訓練數據將以名為 QMNIST 的輸入通道的方式導入訓練環境。訓練開始執行過程中,Amazon S3 上的訓練數據將被下載到模型訓練環境的本地文件系統,訓練腳本 train.py 將從本地磁盤加載數據進行訓練。
- # Start training
- estimator.fit({'QMNIST': s3_data_location}, wait=False)
根據您選擇的訓練實例不同,訓練過程中可能持續幾十分鐘到幾個小時不等。建議設置 wait 參數為 False ,這個選項將使筆記本與訓練任務分離,在訓練時間長、訓練日志多的場景下,可以避免筆記本上下文因為網絡中斷或者會話超時而丟失。訓練任務脫離筆記本后,輸出將暫時不可見,可以執行如下代碼,筆記本將獲取并載入此前的訓練回話,
- %%time
- from sagemaker.estimator import Estimator
- # Attaching previous training session
- training_job_name = estimator.latest_training_job.name
- attached_estimator = Estimator.attach(training_job_name)
由于的模型設計考慮到了GPU對訓練加速的能力,所以用GPU實例訓練會比CPU實例快一些,例如,p3.2xlarge 實例大概需要15分鐘左右,而 c5.xlarge 實例則可能需要6小時以上。目前模型不支持分布、并行訓練,所以多實例、多CPU/GPU并不會帶來更多的訓練速度提升。
訓練完成后,模型將被上傳到 Amazon S3 里,上傳位置由創建 PyTorch 對象時提供的 output_path 參數指定。
模型的驗證
您將從 Amazon S3 下載經過訓練的模型到筆記本所在實例的本地文件系統,下面的代碼將載入模型,然后輸入一個隨機數,獲得推理結果,以圖片形式展現出來。執行如下指令加載訓練好的模型,并通過這個模型產生一組『手寫』數字字體。
- from helper import *
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- from dcgan.model import Generator
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- params = {'nz': nz, 'nc': nc, 'ngf': ngf}
- model = load_model(Generator, params, "./model/generator_state.pth", device=device)
- img = generate_fake_handwriting(model, batch_size=batch_size, nz=nz, device=device)
- plt.imshow(np.asarray(img))

結論與總結
近些年成長快速的 PyTorch 框架正在得到廣泛的認可和應用,越來越多的新模型采用 PyTorch 框架,也有模型被遷移到 PyTorch 上,或者基于 PyTorch 被完整再實現。生態環境持續豐富,應用領域不斷拓展,PyTorch 已成為事實上的主流框架之一。Amazon SageMaker 與多種 AWS 服務緊密集成,比如,各種類型和尺寸的 Amazon EC2 計算實例、Amazon S3、Amazon ECR 等等,為機器學習工程實踐提供了端到端的、一致的體驗。Amazon SageMaker 持續支持主流機器學習框架,PyTorch 是這其中之一。用 PyTorch 開發的機器學習算法和模型,可以輕松移植到 Amazon SageMaker 的工程和服務環境里,進而利用 Amazon SageMaker 全托管的 Jupyter Notebook、訓練容器鏡像、服務容器鏡像、訓練任務管理、部署環境托管等功能,簡化機器學習工程復雜度,提高生產效率,降低運維成本。
DCGAN 是生成對抗網絡領域中具里程碑意義的一個,是現今很多復雜生成對抗網絡的基石。文首提到的 StyleGAN,用文本合成圖像的 StackGAN,從草圖生成圖像的Pix2pix,以及互聯網上爭議不斷的 DeepFakes 等等,都有DCGAN的影子。相信通過本文的介紹和工程實踐,對您了解生成對抗網絡的原理和工程方法會有所幫助。