涵蓋18+ SOTA GAN實現,這個圖像生成領域的庫火了
GAN 自從被提出后,便迅速受到廣泛關注。我們可以將 GAN 分為兩類,一類是無條件下的生成;另一類是基于條件信息的生成。近日,來自韓國浦項科技大學的碩士生在 GitHub 上開源了一個項目,提供了條件 / 無條件圖像生成的代表性生成對抗網絡(GAN)的實現。

近日,機器之心在 GitHub 上看到了一個非常有意義的項目 PyTorch-StudioGAN,它是一個 PyTorch 庫,提供了條件 / 無條件圖像生成的代表性生成對抗網絡(GAN)的實現。據主頁介紹,該項目旨在提供一個統一的現代 GAN 平臺,這樣機器學習領域的研究者可以快速地比較和分析新思路和新方法等。
該項目的作者為韓國浦項科技大學的碩士生,他的研究興趣主要包括深度學習、機器學習和計算機視覺。

項目地址:https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
具體而言,該項目具有以下幾個顯著特征:
- 提供了大量 PyTorch 框架的 GAN 實現;
- 基于 CIFAR 10、Tiny ImageNet 和 ImageNet 數據集的 GAN 基準;
- 相較原始實現的更好的性能和更低的內存消耗;
- 提供完全最新 PyTorch 環境的預訓練模型;
- 支持多 GPU(DP、DDP 和多節點 DDP)、混合精度、同步批歸一化、LARS、Tensorboard 可視化和其他分析方法。
對于這個 PyTorch GAN 庫,有網友表示:「看上去很不錯!如果可以提供 top-k 等現代訓練實踐以及各種增強方法就更棒了。」對此,項目作者稱其會在 NeurIPS 論文提交截止日期之后,添加一些改進的方法,如 Sinha 等人的 Tok-K 訓練以及 Langevin 采樣和 SimCLR 增強。

此外,有網友詢問是否可以將該項目用于圖像之外的其他領域。作者表示可以,即使無法使用一些穩定器(如 diffaug、ada 等),依然可以通過調整 dataLoader 來訓練自己的模型。

18+ SOTA GAN 實現
如下圖所示,項目作者提供了 18 + 個 SOTA GAN 的實現,包括 DCGAN、LSGAN、GGAN、WGAN-WC、WGAN-GP、WGAN-DRA、ACGAN、ProjGAN、SNGAN、SAGAN、BigGAN、BigGAN-Deep、CRGAN、ICRGAN、LOGAN、DiffAugGAN、ADAGAN、ContraGAN 和 FreezeD。

cBN:條件批歸一化;AC:輔助分類器;PD:Projection 判別器;CL:對比學習。
其中,需要注意以下幾點:
- G/D_type 表示將標簽信息注入生成器或判別式的方式;
- EMA 表示生成器中應用更新后的指數移動平均線;
- Tiny ImageNet 數據集上的實驗使用的是 ResNet 架構而不是 CNN。
下圖中 StyleGAN2 為即將實現的 GAN 網絡,其中 AdaIN 表示自適應實例歸一化(Adaptive Instance Normalization)。

環境要求
- Anaconda
- Python >= 3.6
- 6.0.0 <= Pillow <= 7.0.0
- scipy == 1.1.0
- sklearn
- seaborn
- h5py
- tqdm
- torch >= 1.6.0
- torchvision >= 0.7.0
- tensorboard
- 5.4.0 <= gcc <= 7.4.0
- torchlars
用戶可以采用以下方法安裝推薦的環境:
- conda env create -f environment.yml -n studiogan
在 docker 中還可以采用以下方式:
- docker pull mgkang/studiogan:latest
以下是創建名字為「studioGAN」容器的命令,同樣也可以使用端口號為 6006 來連接 tensoreboard。
- docker run -it --gpus all --shm-size 128g -p 6006:6006 --name studioGAN -v /home/USER:/root/code --workdir /root/code mgkang/studiogan:latest /bin/bash
使用方法
使用 GPU 0 的情況下,在 CONFIG_PATH 中對于模型的訓練「-t」和評估「-e」進行了定義:
- CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH
在使用 GPU (0, 1, 2, 3) 和 DataParallel 情況下,在 CONFIG_PATH 中對于模型的訓練「-t」和評估「-e」進行了定義:
- CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH
在 python3 src/main.py 程序中查看可用選項,通過 Tensorboard 可以監控 IS、FID、F_beta、Authenticity Accuracies 以及最大奇異值:
- ~ PyTorch-StudioGAN/logs/RUN_NAME>>> tensorboard --logdir=./ --port PORT
可視化以及分析生成圖像
StudioGAN 支持圖像可視化、k 最近鄰分析、線性差值以及頻率分析。所有的結果保存在「./figures/RUN_NAME/*.png」中。
圖像可視化的代碼和示例如下:
- CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -iv -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

k 最近鄰分析,這里固定 K=7,第一列中是生成的圖像:
- CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -knn -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

線性插值(僅適用于有條件的 Big ResNet 模型 )的代碼和示例如下:
- CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -itp -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH
