謝賽寧新作:表征學習有多重要?一個操作刷新SOTA,DiT訓練速度暴漲18倍
擴散模型如何突破瓶頸?成本高又難訓練的DiT/SiT模型如何提升效率?
對于這個問題,紐約大學謝賽寧團隊最近發表的一篇論文找到了一個全新的切入點:提升表征(representation)的質量。
論文的核心或許就可以用一句話概括:「表征很重要!」
用謝賽寧的話來說,即使只是想讓生成模型重建出好看的圖像,仍然需要先學習強大的表征,然后再去渲染高頻的、使圖像看起來更美觀的細節。
這個觀點,Yann LeCun之前也多次強調過。
有網友還在線幫謝賽寧想標題:你這篇論文不如就叫「Representation is all you need」(手動狗頭)
由于觀點一致,這篇研究也獲得了同在紐約大學的Yann LeCun的轉發。
當使用自監督學習訓練視覺編碼器時,我們知道一個事實,使用具有重建損失(reconstruction loss)的解碼器的效果遠遠不如具有特征預測損失(feature prediction loss)和崩潰預防機制的聯合嵌入架構。
這篇來自紐約大學謝賽寧團隊的論文表明,即使只對生成像素感興趣(例如,使用擴散Transformer生成漂亮的圖片),包含特征預測損失也是值得的,以便解碼器的內部表示可以基于預訓練的視覺編碼器(例如 DINOv2)進行特征預測。
REPA的核心思想非常簡單,就是讓擴散模型中的表征與外部更強大的視覺表征進行對齊,但提升效果非常顯著,頗有「他山之石,可以攻玉」的意味。
僅僅是在損失函數添加一項相似度最大化,就能將SiT/DiT的訓練速度提升將近18倍,還刷新了模型的SOTA性能,在ImageNet 256x256上實現了最先進的FID=1.42。
謝賽寧表示,剛看到實驗結果時,他自己也被震驚到了,因為感覺并沒有發明什么全新的東西,而只是意識到了,我們幾乎完全不理解擴散模型和SSL方法學習到的表示。
論文簡介
論文地址:https://arxiv.org/abs/2410.06940
項目地址:https://sihyun.me/REPA/
在生成高維的視覺數據方面,基于去噪方法(如擴散模型)或基于流的生成模型,已經成為了一種可擴展的途徑,并在有挑戰性的的零樣本文生圖/文生視頻任務上取得了非常成功的結果。
最近的研究表明,生成擴散模型中的去噪過程可以在模型內部的隱藏狀態中引入有意義的表示,但這些表示的質量目前仍落后于自監督學習方法,例如DINOv2。
作者認為,訓練大規模擴散模型的一個主要瓶頸,就在于無法有效學習到高質量的內部表示。
如果能夠結合高質量的外部視覺表示,而不是僅僅依靠擴散模型來獨立學習,就可以使訓練過程變得更容易。
為了實現這一點,論文基于經典的擴散Transformer架構,引入了一種簡單的正則化方法REPA(REPresentation Alignment)。
簡單來說,就是將去噪網絡中從噪聲輸入 得到的隱藏狀態??的投影,與外部自監督預訓練的視覺編碼器從干凈圖像??獲得的視覺表示??*進行對齊。
這樣一個非常直給的策略,卻獲得了驚人的結果:應用于流行的SiT或DiT時,模型的訓練效率和生成質量都得到了顯著提高。
具體來說,REPA可以將SiT的訓練速度加快17.5×以上,以不到40萬步的訓練量匹配有700萬步訓練的SiT-XL模型的性能,同時實現了FID=1.42的SOTA結果。
REPA:使用表征對齊的正則化
統一視角的擴散模型+流模型
由于論文希望同時優化基于流的模型SiT和基于去噪的擴散模型DiT,因此首先從統一的隨機插值視角,對這兩種模型進行簡要的回顧。
考慮在t∈[0,T]的連續時間步中,對數據??*~p(??)使用高斯分布ε~??(0,??)添加隨機噪音:
其中,αt和σt分別表示t的遞減和遞增函數。在公式(1)給定的過程中,存在一個帶有速度場(velocity field)的概率流常微分方程:
其中t步時的分布就等于邊際概率pt(??)。
速度??(??,t)可以表示為如下兩個條件期望之和:
這個值可以通過最小化如下訓練目標得到近似值??θ(??,t):
同時,還存在一個反向的隨機微分方程(SDE),帶有擴散系數wt,其中的邊際概率pt(??)與公式(2)相符:
其中,??(??t,t)是一個條件期望值,定義為:
對任意t>0,都可以通過速度??(??,t)計算出??(??,t)的值:
這表明,數據??t也可以通過求解公式(5)的SDE來以另一種方式生成。
以上定義對類似的擴散模型變體,例如DDPM,同樣適用,只是需要將連續的時間步離散化。
方法概述
令p(??)為數據??∈??的未知目標分布,我們的訓練目標就是通過模型對數據的學習得到p(??)的近似。
為了降低計算成本,最近流行的「潛在擴散」方法(latent diffusion)提出學習潛在變量??=E(??)的分布p(??),其中E表示來自預訓練自編碼器(例如KL-VAE)中的編碼部分。
要學習到分布p(??),就需要訓練擴散模型??θ(??t,t),訓練目標是進行速度預測,具體方法如上一節所述。
放在自監督表示學習的背景中,可以將擴散模型看成編碼器fθ:?????和解碼器gθ:?????的組合,其中編碼器負責隱式地學習到表示??t以重建目標??t。
然而,作者提出,用于生成的大型擴散模型并不擅長表征學習,因此REPA引入了外部的語義豐富的表示,從而顯著提升生成性能。
REPA方法概述
模型觀察
擴散模型是否真的不擅長表征學習?這需要更進一步地觀察模型才能確定,為此,研究人員測量并比對了diffusion transformer和當前的SOTA自監督模型DINOv2之間的表征差距,包括語義差距和特征對齊兩種角度。
語義差距
從圖2a可知,預訓練SiT的隱藏層表示在第20層達到最佳狀態,這與之前的研究結果相符,但仍遠遠落后于DINOv2。
特征對齊
如圖2b和2c所示,使用CKNNA值測量SiT和DINOv2之間的表征對齊程度后發現,SiT的對齊效果會隨著模型增大和訓練迭代步數增加而逐漸改善,但即使增加到7M次迭代,和DINOv2之間的對齊程度仍然不足。
事實上,這種差距不僅在SiT中存在,根據附錄C.2的實驗結果,DiT等其他基于去噪的生成式Transformer模型也存在類似的問題。
縮小表征差距
那么,REPA方法究竟如何縮小這種表征差距,讓diffusion transformer在噪聲輸入中也能學到有用的語義特征?
定義N,D分別表示patch數量預訓練編碼器f的嵌入維度,編碼器輸入為無噪聲的圖像??*,輸出為??*=f(??*)∈?N×D。
Diffusion transformer將編碼器輸出??t=fθ(??t)通過一個可訓練的投影頭hφ(MLP)投影為hφ(??t)∈?N×D。
之后,REPA負責將hφ(??t)與??*進行對齊,通過最大化兩者間的patch間相似度:
在實際實現中,將這一項添加到公式(4)定義的基于擴散的訓練目標中,就得到總體的訓練目標:
其中超參數λ>0用于控制模型在去噪目標和表征對齊間的權衡。
從圖3結果可知,REPA減少了表示中的語義差距。
有趣的是,使用REPA后,僅對齊前幾個Transformer塊就能實現足夠程度的表示對齊,從而讓diffusion transformer的靠后層專注于捕獲高頻細節,從而進一步提高生成性能。
實驗結果
為了驗證REPA方法的有效性,實驗在兩種流行的擴散模型訓練目標(即??velocity)上進行了實驗,包括DiT中改進后的DDPM和SiT中的線性隨機插值,但實際中也同樣可以考慮其他的訓練目標。
所用模型默認嚴格遵循SiT和DiT的原始結構(除非有特別說明),包括B/2、L/2、XL/2三種參數設置,如表1所示。
以下實驗旨在回答3個問題:
- REPA能否顯著提升diffusion transformer的訓練?
- REPA在模型規模和表征質量方面是否具有可擴展性?
- 擴散模型的表征能否和多種視覺表征進行對齊?
REPA提升視覺縮放
首先比較兩個SiT-XL/2模型在前400K次迭代期間生成的圖像,它們共享相同的噪聲、采樣器和采樣步數,但其中使用REPA訓練的模型顯示出更好的進展。
REPA在各個方面都展現出了強大的可擴展性
研究人員還改變了預訓練編碼器和Diffusion Transformer的模型大小來檢驗REPA的可擴展性。
圖5a結果表明,與更好的視覺表示相結合可以改善生成效果和線性探測的結果。
此外,如圖5b和c所示,增加模型大小可以在生成和線性評估方面帶來更快的收益,也就是說,模型規模越大,REPA的加速效果越明顯,表現出了強大的可擴展性。
REPA顯著提高訓練效率和生成質量
最后,論文比較了普通DiT或SiT模型在訓練中使用REPA前后的FID值。
在沒有指導的情況下,REPA在400K次迭代時實現了FID=7.9,優于普通模型在7M次迭代后的性能。
此外,使用無分類器引導時,帶有REPA的SiT-XL/2的性能優于SOTA性能(FID=1.42),同時迭代次數減少了7倍。
作者介紹
Sihyun Yu
本文一作Sihyun Yu是KAIST(韓國科學技術院)人工智能專業最后一年的博士生,此前他同樣在KAIST獲得了數學和計算機科學的雙專業學士學位。
他的研究主要集中在減少大型生成模型訓練(和采樣)的內存和計算負擔,其中,對大規模且高效的視頻生成特別感興趣;博士期間,他還曾在英偉達和谷歌研究院擔任實習生。