自回歸解碼加速64倍,谷歌提出圖像合成新模型MaskGIT
生成式 transformer 在合成高保真和高分辨率圖像方面得到了快速普及。但迄今為止最好的生成式 transformer 模型仍是將圖像視為一系列 token,并按照光柵掃描順序(即逐行)解碼圖像。然而這種策略既不是最優的,也不高效。
近日,來自谷歌研究院的研究者提出了一種使用雙向 transformer 解碼器的新型圖像合成模型 MaskGIT。在訓練期間,MaskGIT 通過關注各個方向的 token 來學習預測隨機掩碼 token。在推理階段,模型首先同時生成圖像的所有 token,然后以上一次生成為條件迭代地細化圖像。實驗表明,MaskGIT 在 ImageNet 數據集上顯著優于 SOTA transformer 模型,并將自回歸解碼的速度提高了 64 倍。
論文地址:https://arxiv.org/abs/2202.04200
此外,該研究還表明 MaskGIT 可以輕松擴展到各種圖像編輯任務,例如修復、外推和圖像處理。
相關研究
先前的模型 VQVAE 提出分兩個階段在潛在空間中生成圖像。
第一個階段稱為 tokenization,其中嘗試將圖像壓縮到離散的潛在空間中,這一階段主要包含三個部分:
一個編碼器 E ,負責學習將圖像 x∈ tokenize 成潛在嵌入 E(x);一個用于最近鄰查找 codebook
,以將嵌入量化為視覺 token;一個解碼器 G,它根據視覺 token e 預測重建圖像
。
第二個階段首先使用深度自回歸模型預測視覺 token 的潛在先驗,然后使用第一階段的解碼器將 token 序列映射到圖像像素中。
這種兩階段范式是很有效的,因此幾種常用的方法都遵循了這種范式,例如 DALL-E、VQGAN。其中,VQGAN 在第一階段增加了對抗性損失和感知損失以提高圖像保真度。
MaskGIT
上述使用兩階段范式的方法由于仍然采用自回歸模型,因此第二階段的解碼時間與 token 序列長度成比例。而本研究的目標是設計一種利用并行解碼和雙向生成的新圖像合成范式,遵循上述兩階段方案并改進第二階段。第一階段采用與 VQGAN 模型相同的設置,并將潛在的改進留給未來工作的 tokenization 步驟;對于第二階段,研究者提出通過掩碼視覺 token 建模(Masked Visual Token Modeling,MVTM 學習雙向 transformer。
訓練中的 MVTM
該研究用表示將圖像輸入到 VQ 編碼器獲得的潛在 token,其中 N 是重構后的 token 矩陣的長度,
是對應的二進制掩碼。在訓練期間,該研究采樣 token 的子集,并用一個特殊的 [MASK] token 替代它們。如果 m_i=1,就用 [MASK] 取代 token y_i;如果 m_i=0,y_i 保留。
采樣過程由掩碼調度函數(mask scheduling function) 進行參數化,然后按照如下步驟:
首先從 0 到 1 采樣一個比率,然后在 Y 中統一選擇 個 token 來放置掩碼,其中 N 是長度。掩碼調度顯著影響了圖像的生成質量。
迭代解碼
在自回歸解碼中,token 是根據先前生成的輸出順序生成的。這個過程是不可并行的,而圖像的 token 長度通常比語言長得多,因此速度非常慢。該研究提出了一種新型解碼方法,其中圖像中的所有 token 都是同時并行生成的,這基于 MTVM 的雙向自注意力。
理論上講,該模型能夠推斷出所有 token 并在單次傳遞中生成整個圖像,但訓練任務的不一致給該研究帶來了挑戰。為了在推理時生成圖像,該研究從一個空白 canvas 開始,所有 token 都被掩碼,即。該研究提出的迭代解碼方法,每次迭代的算法運行步驟如下:
1. 預測2. 采樣3. 掩碼調度4. 掩碼
掩碼設計
研究者發現圖像的生成質量受到掩碼設計的顯著影響。該方法通過一個掩碼調度函數對掩碼過程進行建模,該函數負責計算給定潛在 token 的掩碼比率。在推理期間,函數
用
的輸入代表解碼的進度;在訓練期間,該研究在 [0,1) 中隨機采樣一個比率 r 來模擬各種解碼場景。
實驗
該研究從質量、效率和靈活性方面對 MaskGIT 在圖像生成方面進行了實驗評估。
類條件圖像合成
該研究在 ImageNet 256 X 256 和 ImageNet 512 X 512 上評估了 MaskGIT 模型在類條件(class-conditional)圖像合成任務上的性能,主要結果如下表 1 所示。
質量。在 ImageNet 256 X 256 上,不使用任何特殊的采樣策略,MaskGIT 在 FID 和 IS 方面都顯著優于 VQGAN。
速度。該研究通過評估每個模型生成樣本所需的步驟數(前向傳遞)來評估模型速度。如表 1 所示,在所有基于非 GAN 的模型中,MaskGIT 在兩種分辨率上所需的步驟最少。
為了進一步證實 MaskGIT 和自回歸模型之間的速度差異,該研究對 MaskGIT 和 VQGAN 的解碼過程進行了運行時比較。如下圖 4 所示,MaskGIT 將 VQGAN 顯著加速了 30-64 倍,隨著圖像分辨率(以及輸入 token 長度)的增加,加速變得更加明顯。
多樣性。除了樣本質量外,該研究還將分類準確率得分 (CAS) 和 Precision/Recall 作為評估樣本多樣性的兩個指標。與 BigGAN 的樣本相比,MaskGIT 的樣本更加多樣化,具有更多種光照、姿態、規模和語境,如下圖 5 所示。
圖像編輯應用
該研究展示了 MaskGIT 在三個圖像編輯任務上的直接應用:類條件圖像編輯、圖像修復和圖像擴展(outpainting)。如果將任務看作對初始二進制掩碼 M MaskGIT 在其迭代解碼中使用約束,那么這三個任務幾乎都可以輕松地轉換為 MaskGIT 可以處理的任務。
該研究表明,無需修改架構或任何特定于任務的訓練,MaskGIT 就能夠在所有三個應用程序上產生非常優秀的結果。此外,MaskGIT 在圖像修復和擴展方面獲得了與專用模型相當的性能。
在類條件圖像編輯任務上,該研究定義了一個新的類條件圖像編輯任務來展示 MaskGIT 的靈活性。模型在給定類的邊界框內重新生成特定內容,同時保留語境,即框外的內容。由于違背了預測順序,因此自回歸方法是不可行的。
然而,對于 MaskGIT,如果將邊界框區域視為迭代解碼算法的初始掩碼的輸入,這個問題就迎刃而解了。下圖 6 給出了一些示例結果。
表 2 比較了幾種方法的定量結果。MaskGIT 在 FID 和 IS 中均以顯著優勢擊敗 DeepFill 和 HiFill,同時獲得接近 SOTA 修復方法 CoModGAN 的分數。
如下圖 7 所示,MaskGIT 還能夠在給定相同輸入和不同種子的情況下合成不同的結果。
消融實驗
為了驗證新設計的效用,該研究在 ImageNet 256×256 的默認設置上進行了消融實驗。MaskGIT 的一個關鍵設計是用于訓練和迭代解碼的掩碼調度函數,實驗結果如下表 3 和圖 8 所示。
值得注意的是,如圖 8 所示,在相同的設置下,更多的迭代不一定更好:隨著迭代次數 T 的增加,除了對數函數在整個過程中都表現不佳以外,其他所有函數都達到了一個「sweet spot」位置,即模型的性能在再次惡化之前達到峰值。