顏水成/程明明新作!Sora核心組件DiT訓練提速10倍,Masked Diffusion Transformer V2開源
DiT作為效果驚艷的Sora的核心技術之一,利用Difffusion Transfomer 將生成模型擴展到更大的模型規模,從而實現高質量的圖像生成。
然而,更大的模型規模導致訓練成本飆升。
為此,來自Sea AI Lab、南開大學、昆侖萬維2050研究院的顏水成和程明明研究團隊在ICCV 2023提出的Masked Diffusion Transformer利用mask modeling表征學習策略通過學習語義表征信息來大幅加速Diffusion Transfomer的訓練速度,并實現SoTA的圖像生成效果。
圖片
論文地址:https://arxiv.org/abs/2303.14389
GitHub地址:https://github.com/sail-sg/MDT
近日,Masked Diffusion Transformer V2再次刷新SoTA, 相比DiT的訓練速度提升10倍以上,并實現了ImageNet benchmark 上 1.58的FID score。
最新版本的論文和代碼均已開源。
背景
盡管以DiT 為代表的擴散模型在圖像生成領域取得了顯著的成功,但研究者發現擴散模型往往難以高效地學習圖像中物體各部分之間的語義關系,這一局限性導致了訓練過程的低收斂效率。
圖片
例如上圖所示,DiT在第50k次訓練步驟時已經學會生成狗的毛發紋理,然后在第200k次訓練步驟時才學會生成狗的一只眼睛和嘴巴,但是卻漏生成了另一只眼睛。
即使在第300k次訓練步驟時,DiT生成的狗的兩只耳朵的相對位置也不是非常準確。
這一訓練學習過程揭示了擴散模型未能高效地學習到圖像中物體各部分之間的語義關系,而只是獨立地學習每個物體的語義信息。
研究者推測這一現象的原因是擴散模型通過最小化每個像素的預測損失來學習真實圖像數據的分布,這個過程忽略了圖像中物體各部分之間的語義相對關系,因此導致模型的收斂速度緩慢。
方法:Masked Diffusion Transformer
受到上述觀察的啟發,研究者提出了Masked Diffusion Transformer (MDT) 提高擴散模型的訓練效率和生成質量。
MDT提出了一種針對Diffusion Transformer 設計的mask modeling表征學習策略,以顯式地增強Diffusion Transformer對上下文語義信息的學習能力,并增強圖像中物體之間語義信息的關聯學習。
圖片
如上圖所示,MDT在保持擴散訓練過程的同時引入mask modeling學習策略。通過mask部分加噪聲的圖像token,MDT利用一個非對稱Diffusion Transformer (Asymmetric Diffusion Transformer) 架構從未被mask的加噪聲的圖像token預測被mask部分的圖像token,從而同時實現mask modeling 和擴散訓練過程。
在推理過程中,MDT仍保持標準的擴散生成過程。MDT的設計有助于Diffusion Transformer同時具有mask modeling表征學習帶來的語義信息表達能力和擴散模型對圖像細節的生成能力。
具體而言,MDT通過VAE encoder將圖片映射到latent空間,并在latent空間中進行處理以節省計算成本。
在訓練過程中,MDT首先mask掉部分加噪聲后的圖像token,并將剩余的token送入Asymmetric Diffusion Transformer來預測去噪聲后的全部圖像token。
Asymmetric Diffusion Transformer架構
圖片
如上圖所示,Asymmetric Diffusion Transformer架構包含encoder、side-interpolater(輔助插值器)和decoder。
圖片
在訓練過程中,Encoder只處理未被mask的token;而在推理過程中,由于沒有mask步驟,它會處理所有token。
因此,為了保證在訓練或推理階段,decoder始終能處理所有的token,研究者們提出了一個方案:在訓練過程中,通過一個由DiT block組成的輔助插值器(如上圖所示),從encoder的輸出中插值預測出被mask的token,并在推理階段將其移除因而不增加任何推理開銷。
MDT的encoder和decoder在標準的DiT block中插入全局和局部位置編碼信息以幫助預測mask部分的token。
Asymmetric Diffusion Transformer V2
圖片
如上圖所示,MDTv2通過引入了一個針對Masked Diffusion過程設計的更為高效的宏觀網絡結構,進一步優化了diffusion和mask modeling的學習過程。
這包括在encoder中融合了U-Net式的long-shortcut,在decoder中集成了dense input-shortcut。
其中,dense input-shortcut將添加噪后的被mask的token送入decoder,保留了被mask的token對應的噪聲信息,從而有助于diffusion過程的訓練。
此外,MDT還引入了包括采用更快的Adan優化器、time-step相關的損失權重,以及擴大掩碼比率等更優的訓練策略來進一步加速Masked Diffusion模型的訓練過程。
實驗結果
ImageNet 256基準生成質量比較
圖片
上表比較了不同模型尺寸下MDT與DiT在ImageNet 256基準下的性能對比。
顯而易見,MDT在所有模型規模上都以較少的訓練成本實現了更高的FID分數。
MDT的參數和推理成本與DiT基本一致,因為正如前文所介紹的,MDT推理過程中仍保持與DiT一致的標準的diffusion過程。
對于最大的XL模型,經過400k步驟訓練的MDTv2-XL/2,顯著超過了經過7000k步驟訓練的DiT-XL/2,FID分數提高了1.92。在這一setting下,結果表明了MDT相對DiT有約18倍的訓練加速。
對于小型模型,MDTv2-S/2 仍然以顯著更少的訓練步驟實現了相比DiT-S/2顯著更好的性能。例如同樣訓練400k步驟,MDTv2以39.50的FID指標大幅領先DiT 68.40的FID指標。
更重要的是,這一結果也超過更大模型DiT-B/2在400k訓練步驟下的性能(39.50 vs 43.47)。
ImageNet 256基準CFG生成質量比較
圖片
我們還在上表中比較了MDT與現有方法在classifier-free guidance下的圖像生成性能。
MDT以1.79的FID分數超越了以前的SOTA DiT和其他方法。MDTv2進一步提升了性能,以更少的訓練步驟將圖像生成的SOTA FID得分推至新低,達到1.58。
與DiT類似,我們在訓練過程中沒有觀察到模型的FID分數在繼續訓練時出現飽和現象。
MDT在PaperWithCode的leaderboard上刷新SoTA
收斂速度比較
圖片
上圖比較了ImageNet 256基準下,8×A100 GPU上DiT-S/2基線、MDT-S/2和MDTv2-S/2在不同訓練步驟/訓練時間下的FID性能。
得益于更優秀的上下文學習能力,MDT在性能和生成速度上均超越了DiT。MDTv2的訓練收斂速度相比DiT提升10倍以上。
MDT在訓練步驟和訓練時間方面大相比DiT約3倍的速度提升。MDTv2進一步將訓練速度相比于MDT提高了大約5倍。
例如,MDTv2-S/2僅需13小時(15k步驟)就展示出比需要大約100小時(1500k步驟)訓練的DiT-S/2更好的性能,這揭示了上下文表征學習對于擴散模型更快的生成學習至關重要。
總結&討論
MDT通過在擴散訓練過程中引入類似于MAE的mask modeling表征學習方案,能夠利用圖像物體的上下文信息重建不完整輸入圖像的完整信息,從而學習圖像中語義部分之間的關聯關系,進而提升圖像生成的質量和學習速度。
研究者認為,通過視覺表征學習增強對物理世界的語義理解,能夠提升生成模型對物理世界的模擬效果。這正與Sora期待的通過生成模型構建物理世界模擬器的理念不謀而合。希望該工作能夠激發更多關于統一表征學習和生成學習的工作。
參考資料:
https://arxiv.org/abs/2303.14389