從VAE到擴散模型:一文解讀以文生圖新范式
1 前言
在發布DALL·E的15個月后,OpenAI在今年春天帶了續作DALL·E 2,以其更加驚艷的效果和豐富的可玩性迅速占領了各大AI社區的頭條。近年來,隨著生成對抗網絡(GAN)、變分自編碼器(VAE)、擴散模型(Diffusion models)的出現,深度學習已向世人展現其強大的圖像生成能力;加上GPT-3、BERT等NLP模型的成功,人類正逐步打破文本和圖像的信息界限。
在DALL·E 2中,只需輸入簡單的文本(prompt),它就可以生成多張1024*1024的高清圖像。這些圖像甚至可以將不合常理的語義表示,以超現實主義的形式創造出天馬行空的視覺效果,例如圖1中“寫實風格的騎馬的宇航員(An astronaut riding a horse in a photorealistic style)”。
圖1. DALL·E 2生成示例
本文將深入解讀DALL·E等新范式如何通過文本創造出眾多驚人的圖像,文中涵蓋大量背景知識和基礎技術的介紹,同樣適合初涉圖像生成領域的讀者。
2 圖像生成
圖2. 主流圖像生成方法
自2014年生成對抗網絡(GAN)誕生以來,圖像生成研究成為了深度學習乃至整個人工智能領域的重要前沿課題,現階段技術發展之強已達到以假亂真的程度。除了為人熟知的生成對抗網絡(GAN),主流方法還包括變分自編碼器(VAE)和基于流的生成模型(Flow-based models),以及近期頗受關注的擴散模型(Diffusion models)。借助圖2我們探尋一下各個方法的特點和區別。
2.1 生成對抗網絡(GAN)
GAN的全稱是 G enerative A dversarial N etworks,從名稱不難讀出“對抗(Adversarial)”是其成功之精髓。對抗的思想受博弈論啟發,在訓練生成器(Generator)的同時,訓練一個判別器(Discriminator)來判斷輸入是真實圖像還是生成圖像,兩者在一個極小極大游戲中相互博弈不斷變強,如式(1)。當從隨機噪聲生成足以“騙”過的圖像時,我們認為較好地擬合出了真實圖像的數據分布,通過采樣可以生成大量逼真的圖像。
GAN是生成式模型中應用最廣泛的技術,在圖像、視頻、語音和NLP等眾多數據合成場景大放異彩。除了直接從隨機噪聲生成內容外,我們還可以將條件(例如分類標簽)作為輸入加入生成器和判別器,使得生成結果符合條件輸入的屬性,讓生成內容得以控制。雖然GAN效果出眾,但由于博弈機制的存在,其訓練穩定性差且容易出現模式崩潰(Mode collapse),如何讓模型平穩地達到博弈均衡點,也是GAN的熱點研究話題。
2.2 變分自編碼器(VAE)
變分自編碼器(Variational Autoencoder)是自編碼器的一種變體,傳統的自編碼器旨在以無監督的方式訓練一個神經網絡,完成將原始輸入壓縮成中間表示和將恢復成兩個過程,前者通過編碼器(Encoder)將原始高維輸入轉換為低維隱層編碼,后者通過解碼器(Decoder)從編碼中重建數據。不難看出,自編碼器的目標是學習一個恒等函數,我們可以使用交叉熵(Cross-entropy)或者均方差(Mean Square Error)構建重建損失量化輸入和輸出的差異。如圖3所示,在上述過程中我們獲得了低緯度的隱層編碼,它捕捉了原始數據的潛在屬性,可以用于數據壓縮和特征表示。
圖3. 自編碼器的潛在屬性編碼
由于自編碼器僅關注隱層編碼的重建能力,其隱層空間分布往往是無規律和不均勻的,在連續的隱層空間隨機采樣或者插值得到一組編碼通常會產生無意義和不可解釋的生成結果。為了構建一個有規律的隱層空間,使得我們可以在不同潛在屬性上隨機地采樣和平滑地插值,最后通過解碼器生成有意義的圖像,研究者們在2014年提出了變分自編碼器。
變分自編碼器不再將輸入映射成隱層空間中的一個固定編碼,而是轉換成對隱層空間的概率分布估計,為了方便表示我們假設先驗分布是一個標準高斯分布。同樣的,我們訓練一個概率解碼器建模,實現從隱層空間分布到真實數據分布的映射。當給定一個輸入,我們通過后驗分布估計出關于分布的參數(多元高斯模型的均值和協方差),并在此分布上采樣,可使用重參數化技巧使采樣可導(為隨機變量),最后通過概率解碼器輸出關于的分布,如圖4所示。為了使生成圖像盡量真實,我們需要求解后驗分布,目標是最大化真實圖像的對數似然。
圖4. 變分自編碼器的采樣生成過程
遺憾的是,真實的后驗分布根據貝葉斯模型包含對在連續空間上的積分,是不可直接求解的。為了解決上述問題,變分自編碼器使用了變分推理的方法,引入一個可學習的概率編碼器去近似真實的后驗分布,使用KL散度度量兩個分布的差異,將這個問題從求解真實的后驗分布轉化為如何縮小兩個分布之間的距離。
我們省略中間推導過程,將上式展開得到式(2),
由于KL散度非負,我們可以將我們的最大化目標轉寫成式(3),
綜上,我們將關于概率編碼器和概率解碼器的定義為模型的損失函數,其負數形式稱為的證據下界(Evidence Lower Bound),最大化證據下界等效于最大化目標。上述變分過程是VAE及各種變體的核心思想,通過變分推理將問題轉化為最大化生成真實數據的證據下界。
2.3 基于流的生成模型(Flow-based models)
圖5. 基于流的生成過程
如圖5所示,假設原始數據分布可以通過一系列可逆的轉化函數從已知分布獲得,即。通過雅各布矩陣行列式和變量變化規則,我們可以直接估計真實數據的概率密度函數(式(4)),最大化可計算的對數似然。
是轉換函數的雅各布行列式,因此要求可逆之外還要求容易計算出其雅各布行列式。基于流的生成模型如Glow采用1x1可逆卷積進行精確的密度估計,在人臉生成上取得不錯的效果。
2.4 擴散模型(Diffusion models)
圖6. 擴散模型的擴散和逆向過程
擴散模型定義了正向和逆向兩個過程,正向過程或稱擴散過程是從真實數據分布采樣,逐步向樣本添加高斯噪聲,生成噪聲樣本序列,加噪過程可用方差參數控制,當時,可近似等同于一個高斯分布。其擴散過程是預設的可控過程,加噪過程可用條件分布表示為式(5),
從擴散過程的定義可以看出,我們可以在任意步長上使用上式采樣,
同樣我們也可以把擴散過程逆向,從高斯噪聲中采樣,學習一個模型來估計真實的條件概率分布,因此逆向過程可定義為式(7),
擴散模型的優化目標有多種選擇,例如在訓練過程中由于可以從正向過程直接計算,于是我們可以從預測的分布中采樣,采樣過程可以加入圖像分類和文本標簽作為條件輸入,用最小均方差優化重建損失,這個過程等效于自編碼器。
在去噪擴散概率模型DDPM中,作者通過重參數化技術構建了簡化版的噪聲預測模型損失(式(8)),在步長時輸入加噪數據 訓練模型去預測噪聲
,推理過程中使用
預測去噪數據 的高斯分布均值,實現人臉圖像去噪。
3 多模態表示學習
3.1 NLP on Transformer
圖7. BERT與GPT
BERT與GPT是近年來NLP領域中非常強大的預訓練語言模型,在文章生成,代碼生成,機器翻譯,Q&A等下游任務中取得巨大突破。兩者均使用了Transformer作為算法的主要框架,實現細節上略有不同(圖7)。
BERT本質上是一個雙向編碼器,通過Mask Language Model(MLM)和Next Sentence Prediction(NSP)兩個任務,使用自監督的方式學習文本的特征表示,可代替Word2Vec遷移至其他的學習任務中。GPT本質是自回歸解碼器,通過使用海量數據和不斷堆疊模型,最大化語言模型預測下個文本的似然值。重要的是,訓練過程中GPT的后序文本被mask使其在前序文本訓練預測時不可見,而在BERT中所有文本均相互可見并參與self-attention計算,BERT通過隨機mask或替換輸入,提升模型魯棒性和表達能力。
3.2 ViT(Vision Transformer)
Transformer在NLP領域取得的巨大成功,引發了研究者們對于其圖像特征表達能力的思考。與NLP不同,圖像信息是數量龐大和冗余的,直接使用Transformer建模會因Tokens數量過大導致模型無法學習。直到2020年研究者們提出了ViT,通過Patch和線性投影的方法,降低了圖像數據的維度,使用Transformer Encoder作為圖像編碼器輸出分類預測結果,取得了可觀的效果。
圖8. ViT
現如今Transformer已成為圖像處理領域新的研究對象,以其強大的潛力不斷挑戰CNN的地位。
3.3 CLIP
CLIP(Contrastive Language-Image Pretraining)是OpenAI提出的連接圖像和文本特征表示的對比學習方法。如圖9所示,CLIP成功將文本-圖像對通過Transformer編碼生成Tokens對,使用點積運算衡量相似度,由此對于每個文本我們獲得關于所有圖像的one-hot分類概率,反之每個圖像也能獲得關于所有文本的分類概率。在訓練過程中,我們對圖9(1)概率矩陣每行每列計算交叉熵損失進行優化。
圖9. CLIP
CLIP將文本和圖像的特征表示映射到同一空間,雖然沒有實現跨模態的信息傳遞,但作為特征壓縮、相似性度量和跨模態表示學習的方法,是十分有效的。直觀的,我們把圖像Tokens在標簽范圍生成的所有文本提示中與之特征最相似的輸出,即完成了一次圖像分類(圖9(2)),特別當圖像和標簽的數據分布未曾在訓練集出現過,CLIP仍然有零樣本(zero-shot)學習的能力。
4 跨模態圖像生成
經過前面兩章的介紹,我們系統性地回顧了圖像生成和多模態表示學習相關基礎技術,本章將介紹三個最新的跨模態圖像生成方法,解讀它們如何使用這些基礎技術進行建模。
4.1 DALL·E
DALL·E由OpenAI在2021年初提出,旨在訓練一個輸入文本到輸出圖像的自回歸解碼器。由CLIP的成功經驗可知,文本特征和圖像特征可以編碼在同一特征空間中,因此我們可以使用Transformer將文本和圖像特征自回歸建模為單個數據流(“autoregressively models the text and image tokens as a single stream of data”)。
DALL·E的訓練過程分成兩個階段,一是訓練一個變分自編碼器用于圖像編解碼,二是訓練一個文本和圖像的自回歸解碼器用于預測生成圖像的Tokens,如圖10所示。
圖10. DALL·E的訓練過程
推理過程則比較直觀,將文本Tokens用自回歸Transformer逐步解碼出圖像Tokens,解碼過程中我們可以通過分類概率采樣多組樣本,再將多組樣本Tokens輸入變分自編碼中解碼出多張生成圖像,并通過CLIP相似性計算排序擇優,如圖11所示。
圖11. DALL·E的推理過程
和VAE一樣我們用概率編碼器和概率解碼器,分別建模隱層特征的后驗概率分布和生成圖像的似然概率分布,使用建模由Transformer預測的文本和圖像的聯合概率分布作為先驗(在第一階段初始化為均勻分布),同理可得優化目標的證據下界,
在第一階段的訓練過程中,DALL·E使用了一個離散變分自編碼器(Discrete VAE)簡稱dVAE,是Vector Quantized VAE(VQ-VAE)的升級版。在VAE中我們用一個概率分布刻畫了連續的隱層空間,通過隨機采樣得到隱層編碼,但是這個編碼并不像離散的語言文字具有確定性。為了學習圖像隱層空間的“語言”,VQ-VAE使用了一組可學習的向量量化表示隱層空間,這個量化的隱層空間我們稱為Embedding Space或者Codebook/Vocabulary。VQ-VAE的訓練過程和預測過程旨在尋找與圖像編碼向量距離最近的隱層向量,再將映射得到的向量語言解碼成圖像(圖12),損失函數由三部分構成,分別優化重構損失、更新Embedding Space和更新編碼器,梯度終止。
圖12. VQ-VAE
VQ-VAE由于最近鄰選擇假設使其后驗概率是確定的,即距離最近的隱層向量概率為1其余為0,不具有隨機性;距離最近的向量選擇過程不可導,使用了straight-through estimator方法將的梯度傳遞給。
圖13. dVAE
為了優化上述問題,DALL·E使用Gumbel-Softmax構建了新的dVAE(圖13),解碼器的輸出變為Embedding Space上32*32個K=8192維分類概率,在訓練過程中對分類概率的Softmax計算加入噪聲引入隨機性,使用逐步減小的溫度讓概率分布近似one-hot編碼,對隱層向量的選擇重參數化使其可導(式(11)),推理過程中仍取最近鄰。
PyTorch實現中可設置hard=True輸出近似的one-hot編碼,同時通過 y_hard = y_hard - y_soft.detach() + y_soft 保持可導。
當第一階段訓練完成后,我們可以固定dVAE對于每對文本-圖像生成預測目標的圖像Tokens。在第二階段訓練過程中,DALL·E使用BPE方法將文本先編碼成和圖像Tokens相同維度d=3968的文本Tokens,再將文本Tokens和圖像Tokens Concat到一起,加入位置編碼和Padding編碼,使用Transformer Encoder進行自回歸預測,如圖14所示。為了提升計算速度,DALL·E還采用了Row、Column、Convolutional三種稀疏化的attention mask機制。
圖14. DALL·E的自回歸解碼器
基于上述實現,DALL·E可根據文本輸入不僅可生成“真實”的圖像,還可進行融合創作、場景理解和風格轉化,如圖15。此外,DALL·E在零樣本和專業領域的效果可能變差,且生成的圖像分辨率(256*256)較低。
圖15. DALL·E的多種生成場景
4.2 DALL·E 2
為了進一步提升圖像生成質量和探求文本-圖像特征空間的可解釋性,OpenAI結合擴散模型和CLIP在2022年4月提出了DALL·E 2,不僅將生成尺寸增加到了1024*1024,還通過特征空間的插值操作,可視化了文本-圖像特征空間的遷移過程。
如圖16所示,DALL·E 2將CLIP對比學習得到的text embedding、image embedding作為模型輸入和預測對象,具體過程是學習一個先驗Prior,從text預測對應的image embedding,文章分別用自回歸Transformer和擴散模型兩種方式訓練,后者在各數據集上表現更好;再學習一個擴散模型解碼器UnCLIP,可看做是CLIP圖像編碼器的逆向過程,將Prior預測得到的image embedding作為條件加入中實現控制,text embedding和文本內容作為可選條件,為了提升分辨率UnCLIP還增加了兩個上采樣解碼器(CNN網絡)用于逆向生成更大尺寸的圖像。
圖16. DALL·E 2
在Prior的擴散模型訓練中,DALL·E 2使用了一個Transformer Decoder預測擴散過程,輸入序列為BPE-encoded text + text embedding + timestep embedding+ 當前加噪的image embedding,預測去噪的image embedding,用MSE構建損失函數,
DALL·E 2為避免模型對于特定的文本標簽產生定向類型的生成結果,降低了特征豐富性,對于擴散模型的預測條件增加限制,保證無分類器引導(classifier-free guidance)。例如,在Prior和UnCLIP的擴散模型訓練中,對于加入text embedding等條件設置drop概率,使生成過程不完成依賴條件輸入。因此在逆向生成過程中,我們可以通過image embedding采樣生成同一張圖像的不同變體同時保持基本特征,還可以分別在image embedding和text embedding插值,控制插值比例可生成平滑遷移的可視化結果,如圖17所示。
圖17. DALL·E 2可實現的圖像特征保持和遷移
DALL·E 2對Prior和UnCLIP的有效性做了大量驗證實驗,例如通過三種方式1)只將文本內容輸入UnCLIP生成模型;2)只將文本內容和text embedding輸入UnCLIP生成模型;3)在上述方法基礎上加入Prior預測的image embedding,三種方法的生成效果逐漸提升驗證了Prior有效性。另外,DALL·E 2使用了PCA對隱層空間的embedding降維,隨著維度降低生成圖像的語義特征逐漸減弱。最后,DALL·E 2在MS-COCO數據集上對比了其他方法,取得了FID= 10.39最好的生成質量(圖18)。
圖18. DALL·E 2在MS-COCO數據集上的對比結果
4.3 ERNIE-VILG
ERNIE-VILG是百度文心在2022年初提出的中文場景的文本-圖像雙向生成模型。
圖19. ERNIE-VILG
ERNIE-VILG的思路和DALL·E相似,通過預訓練的變分自編碼器編碼圖像特征,使用Transformer將文本Tokens和圖像Tokens自回歸預測,主要不同點在于:
- ERNIE-VILG依靠百度文心平臺技術,可以處理中文場景;
- 除了Text-to-Image自回歸過程,還建模了Image-to-Text方向過程,且雙向過程參數共享;
- Text-to-Image自回歸過程中,Text Tokens之間不做mask處理;
- 圖像編解碼使用了VQ-VAE和VQ-GAN,通過map&flatten將的圖像解碼過程與自回歸過程連接,實現了端到端訓練。
ERNIE-VILG的另一個強大之處是,在中文場景可以處理多個物體和復雜位置關系的生成問題,如圖20。
圖20. ERNIE-VILG的生成示例
四、總結
本文通過實例解讀了最新的以文生圖的新范式,包含變分自編碼器和擴散模型等生成方法的應用,CLIP等文本-圖像潛在空間表示學習的方法,以及離散化和重參數化等建模技術。
現如今文本到圖像的生成技術有較高的門檻,其訓練成本遠超人臉識別、機器翻譯、語音合成等單模態方法,以DALL·E為例,OpenAI收集并標注了2.5億對樣本,使用了1024塊V100 GPU訓練了120億參數量的模型。此外,圖像生成領域一直存在種族歧視、暴力情色、敏感隱私等問題。從2020年開始,越來越多的AI團隊投入到跨模態生成研究中,不久的將來我們可能在真實世界和生成世界中真假難分。