Make U-Nets Great Again!北大&華為提出擴散架構U-DiT,六分之一算力即可超越DiT
Sora 的發(fā)布讓廣大研究者及開發(fā)者深刻認識到基于 Transformer 架構擴散模型的巨大潛力。作為這一類的代表性工作,DiT 模型拋棄了傳統(tǒng)的 U-Net 擴散架構,轉(zhuǎn)而使用直筒型去噪模型。鑒于直筒型 DiT 在隱空間生成任務上效果出眾,后續(xù)的一些工作如 PixArt、SD3 等等也都不約而同地使用了直筒型架構。
然而令人感到不解的是,U-Net 結構是之前最常用的擴散架構,在圖像空間和隱空間的生成效果均表現(xiàn)不俗;可以說 U-Net 的 inductive bias 在擴散任務上已被廣泛證實是有效的。因此,北大和華為的研究者們產(chǎn)生了一個疑問:能否重新拾起 U-Net,將 U-Net 架構和 Transformer 有機結合,使擴散模型效果更上一層樓?帶著這個問題,他們提出了基于 U-Net 的 DiT 架構 U-DiT。
- 論文標題:U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers
- 論文地址:https://arxiv.org/pdf/2405.02730
- GitHub 地址:https://github.com/YuchuanTian/U-DiT
從一個小實驗談開去
首先,研究者開展了一個小實驗,在實驗中嘗試著將 U-Net 和 DiT 模塊簡單結合。然而,如表 1 所示,在相似的算力比較下,U-Net 的 DiT(DiT-UNet)僅僅比原始的 DiT 有略微的提升。
在圖 3 中,作者們展示了從原始的直筒 DiT 模型一步步演化到 U-DiT 模型的過程。
根據(jù)先前的工作,在擴散中 U-Net 的主干結構特征圖主要為低頻信號。由于全局自注意力運算機制需要消耗大量算力,在 U-Net 的主干自注意力架構中可能存在冗余。這時作者注意到,簡單的下采樣可以自然地濾除噪聲較多的高頻,強調(diào)信息充沛的低頻。既然如此,是否可以通過下采樣來消除對特征圖自注意力中的冗余?
Token 下采樣后的自注意力
由此,作者提出了下采樣自注意力機制。在自注意力之前,首先需將特征圖進行 2 倍下采樣。為避免重要信息的損失,生成了四個維度完全相同的下采樣圖,以確保下采樣前后的特征總維度相同。隨后,在四個特征圖上使用共用的 QKV 映射,并分別獨立進行自注意力運算。最后,將四個 2 倍下采樣的特征圖重新融為一個完整特征圖。和傳統(tǒng)的全局自注意力相比,下采樣自注意力可以使得自注意力所需算力降低 3/4。
令人驚訝的是,盡管加入下采樣操作之后能夠顯著模型降低所需算力,但是卻反而能獲得比原來更好的效果(表 1)。
U-DiT:全面超越 DiT
根據(jù)此發(fā)現(xiàn),作者提出了基于下采樣自注意力機制的 U 型擴散模型 U-DiT。對標 DiT 系列模型的算力,作者提出了三個 U-DiT 模型版本(S/B/L)。在完全相同的訓練超參設定下,U-DiT 在 ImageNet 生成任務上取得了令人驚訝的生成效果。其中,U-DiT-L 在 400K 訓練迭代下的表現(xiàn)比直筒型 DiT-XL 模型高約 10 FID,U-DiT-S/B 模型比同級直筒型 DiT 模型高約 30 FID;U-DiT-B 模型只需 DiT-XL/2 六分之一的算力便可達到更好的效果(表 2、圖 1)。
在有條件生成任務(表 3)和大圖(512*512)生成任務(表 5)上,U-DiT 模型相比于 DiT 模型的優(yōu)勢同樣非常明顯。
研究者們還進一步延長了訓練的迭代次數(shù),發(fā)現(xiàn) U-DiT-L 在 600K 迭代時便能優(yōu)于 DiT 在 7M 迭代時的無條件生成效果(表 4、圖 2)。
U-DiT 模型的生成效果非常出眾,在 1M 次迭代下的有條件生成效果已經(jīng)非常真實。
論文已被 NeurIPS 2024 接收,更多內(nèi)容,請參考原論文。