4倍內存效率,生成和超分雙SOTA!清華&智譜AI發布最新Inf-DiT模型
文章鏈接:https://arxiv.org/pdf/2405.04312
github 鏈接:https://github.com/THUDM/Inf-DiT
擴散模型在近年來的圖像生成中表現出了顯著的性能。然而,由于生成超高分辨率圖像(如 4096 × 4096)時內存需求呈二次方增加,生成圖像的分辨率通常限制在 1024 × 1024。
本文提出了一種單向塊注意力機制,可以在推理過程中自適應地調整內存開銷并處理全局依賴關系。基于這個模塊,本文采用 DiT 結構進行上采樣,并開發了一種能夠對各種形狀和分辨率的圖像進行上采樣的無限超分辨率模型。綜合實驗表明,本文的模型在機器和人類評估中都達到了生成超高分辨率圖像的最新技術水平。與常用的 UNet 結構相比,本文的模型在生成 4096 × 4096 圖像時可以節省超過 5 倍的內存。
介紹
近年來,擴散模型取得了迅速進展,顯著推動了圖像生成和編輯領域的發展。盡管有這些進展,但仍然存在一個關鍵限制:現有圖像擴散模型生成的圖像分辨率通常限制在 1024×1024 像素或更低,這在生成超高分辨率圖像時構成了重大挑戰。而超高分辨率圖像在各種實際應用中是不可或缺的,包括復雜的設計項目、廣告、海報和壁紙的制作等。
一種常用的生成高分辨率圖像的方法是級聯生成,即首先生成低分辨率圖像,然后應用多個上采樣模型逐步提高圖像的分辨率。這種方法將高分辨率圖像的生成分解為多個任務。基于前一階段生成的結果,后一階段的模型只需進行局部生成。在級聯結構的基礎上,DALL-E2 和 Imagen 都能有效生成分辨率為 1024 的圖像。
對于上采樣到更高分辨率圖像的最大挑戰是顯著的 GPU 內存需求。例如,如果在圖像推理中使用廣泛采用的 U-Net 架構(如 SDXL,見下圖 2),觀察到隨著分辨率的增加,內存消耗急劇上升。具體而言,生成一個 4096×4096 分辨率的圖像(包含超過 1600 萬個像素)需要超過 80GB 的內存,這超出了標準的 RTX 4090 或 A100 顯卡的容量。此外,高分辨率圖像生成模型的訓練過程加劇了這些需求,因為它需要額外的內存來存儲梯度、優化器狀態等。
LDM 通過利用變分自編碼器(VAE)來壓縮圖像并在較小的潛在空間中生成圖像,從而減少了內存消耗。然而,文中也強調,過高的壓縮比會顯著降低生成質量,嚴重限制了內存消耗的減少。
基于這一算法,本文優化了擴散 Transformer(DiT),并訓練了一個名為 Inf-DiT 的模型,該模型能夠對不同分辨率和形狀的圖像進行上采樣。此外,設計了幾種技術,包括提供全局圖像 embedding 以增強全局語義一致性,并提供 zero-shot 文本控制能力,以及通過交叉注意力機制提供所有相鄰的低分辨率(LR)塊以進一步增強局部一致性。評估結果表明,Inf-DiT 在機器和人類評估中均顯著優于其他高分辨率生成模型。
主要貢獻如下:
- 基于這些方法,訓練了一個圖像上采樣擴散模型 Inf-DiT,這是一種 700M 的模型,能夠對不同分辨率和形狀的圖像進行上采樣。Inf-DiT 在機器評估(HPDv2 和 DIV2K 數據集)和人類評估中均達到了最新技術水平。
- 設計了多種技術來進一步增強局部和全局一致性,并提供靈活的文本控制的 zero-shot 能力。
方法
單向塊注意力 (UniBA)
生成超高分辨率圖像的關鍵障礙是內存限制
隨著圖像分辨率的增加,網絡中對應的隱藏狀態的大小呈二次方增長。例如,僅一層中形狀為 2048 × 2048 × 1280 的單個隱藏狀態就需要 20GB 的內存,使得生成非常大的圖像變得異常艱難。如何避免在內存中存儲整個圖像的隱藏狀態成為關鍵問題。
- 塊之間的生成依賴關系是單向的,并且可以形成一個有向無環圖(DAG)。
- 每個塊對其他塊只有少量的直接(一級)依賴關系,因為塊及其直接依賴塊的隱藏狀態需要同時保存在內存中。
此外,為了確保整個圖像的一致性,還需要確保每個塊具有足夠大的感受野,以處理長程依賴關系。
根據上述條件和分析,本文選擇了一種高效的實現方式,即下圖 3 所示的單向塊注意力(UniBA)。對于每一層,每個塊直接依賴于三個一階相鄰塊:頂部的塊、左側的塊和左上角的塊。例如,如果本文采用了 Inf-DiT 的基礎架構 Diffusion Transformer(DiT)架構,則塊之間的依賴關系是注意力操作,其中每個塊的查詢向量與其左上角和本身的四個塊的鍵值向量進行交互,如下圖 3 所示。
形式上,Transformer 中的 UniBA 過程可以表示為:
需要注意的是,盡管每個塊在每層中只關注少量相鄰塊,但隨著特征層層傳播,塊可以間接與遠處的塊交互,從而捕捉長短距離關系。本文的設計與自然語言模型 Transformer-XL 具有相似的精神,可以看作是本文的一維情況的特例。
使用 O(N)內存消耗的推理過程
盡管本文的方法可以順序生成每個塊,但它不同于自回歸生成模型,在自回歸生成模型中,下一個塊依賴于前一個塊的最終輸出。在本文的模型中,只要它們的依賴塊的集合已經生成,就可以并行生成任意數量的塊。基于這一特性,本文實現了一個簡單但有效的推理過程。如上面圖 3 所示,本文一次生成 n×n 個塊,從左上到右下。生成一組塊后,本文丟棄不再使用的隱藏狀態(即 KV 緩存),并將新生成的 KV 緩存附加到內存中。
在實際操作中,盡管對于不同的 n,圖像生成的總 FLOPs 保持不變,但由于操作初始化時間和內存分配時間等開銷,當 n 增加時,生成時間會減少。因此,在內存限制允許的情況下,選擇最大的 n 是最優的。
基本模型架構
下圖 4 概述了本文模型 Inf-DiT 的架構。該模型使用了類似 DiT 的主干結構,DiT 將視覺 Transformer(ViT)應用于擴散模型,并證明了其有效性和可擴展性。與基于卷積的架構(如 UNet)相比,DiT 僅使用注意力作為塊之間的交互機制,這使得單向塊注意力的實現變得方便。為了適應單向塊注意力并增強上采樣性能,本文進行了如下的幾項修改和優化。
模型輸入
Inf-DiT 首先將輸入圖像劃分為多個不重疊的塊,然后將這些塊進一步劃分為邊長等于 patch 大小的 patch。與 DiT 不同,考慮到壓縮損失(如顏色偏移和細節損失),Inf-DiT 在 RGB 像素空間中進行 patch 劃分,而不是在潛在空間中。在超分辨率因子為 f 的情況下,Inf-DiT 首先將低分辨率 RGB 圖像條件上采樣 f 倍,然后在特征維度上將其與擴散的噪聲輸入連接起來,再輸入到模型中。
位置編碼
與可以通過卷積操作感知位置關系的基于 UNet 的擴散模型不同,Transformer 中的所有操作(包括自注意力和前饋神經網絡)都是置換不變函數。因此,基于 Transformer 的模型需要輔助輸入顯式位置信息以學習 patch 之間的關系。正如最近在大型語言模型中的研究所示,相對位置編碼在捕捉單詞位置相關性方面比絕對位置編碼更有效,本文參考了旋轉位置編碼(RoPE)的設計,該設計在長上下文生成中表現良好,并將其適配為二維形式用于圖像生成。具體來說,本文將隱藏狀態的通道分成兩半,一半用于編碼 x 坐標,另一半用于編碼 y 坐標,并在這兩個部分中應用 RoPE。
本文創建了一個足夠大的 RoPE 位置編碼表,以確保在生成過程中滿足需求。為了確保模型在訓練期間能看到位置編碼表的所有部分,本文采用了隨機起始點:對于每個訓練圖像,本文隨機分配一個位置(x,y)作為圖像的左上角,而不是默認的(0,0)。
全局和局部一致性
使用 CLIP 圖像 embedding 實現全局一致性
低分辨率(LR)圖像中的全局語義信息(如藝術風格和物體材質)在上采樣過程中起著至關重要的作用。然而,與文本生成圖像模型相比,上采樣模型有一個額外的任務:理解和分析低分辨率圖像的語義信息,這大大增加了模型的負擔。這在沒有文本數據進行訓練時尤其具有挑戰性,因為高分辨率圖像很少有高質量的配對文本,使得這些方面對模型來說很難處理。
使用鄰近 LR 交叉注意力實現局部一致性盡管將 LR 圖像與噪聲輸入連接起來已經為模型學習 LR 和 HR 圖像之間的局部對應關系提供了良好的歸納偏置,但仍然可能存在連續性問題。原因在于,對于給定的 LR 塊,有多種上采樣的可能性,這需要結合多個鄰近 LR 塊進行分析以選擇一個解決方案。假設上采樣僅基于其左側和上方的 LR 塊進行,它可能會選擇一個與右側和下方的 LR 塊沖突的 HR 生成方案。那么,當上采樣右側的 LR 塊時,如果模型認為符合其對應的 LR 塊比與左側塊連續更重要,則會生成與先前塊不連續的 HR 塊。一個簡單的解決方案是將整個 LR 圖像輸入到每個塊中,但當 LR 圖像的分辨率也很高時,這樣做成本太高。
為了解決這個問題,本文引入了鄰近 LR 交叉注意力。在 Transformer 的第一層,每個塊對周圍的 3×3 LR 塊進行交叉注意力,以捕捉附近的 LR 信息。本文的實驗表明,這種方法顯著降低了生成不連續圖像的概率。值得注意的是,這個操作不會改變本文的推理過程,因為在生成之前整個 LR 圖像是已知的。
本文進一步設計了包括無類別指導的連續性、基于 LR 的噪聲初始化、QK 規范化等技術。
實驗
在本節中,本文首先介紹 Inf-DiT 的詳細訓練過程,然后通過機器和人類評價全面評估 Inf-DiT 的性能。結果表明,Inf-DiT 在超高分辨率圖像生成和上采樣任務中均優于所有基線模型。最后,本文進行消融研究,以驗證本文設計的有效性。
訓練細節
數據集
數據集由分辨率高于 1024×1024 且美學評分高于 5 的 LAION-5B 子集和來自互聯網的 10 萬張高分辨率壁紙組成。與之前的工作[20,23,30]相同,本文在訓練期間使用固定大小的 512×512 分辨率的圖像裁剪。由于上采樣可以僅使用局部信息進行,因此在推理時可以直接在更高分辨率下進行,這對大多數生成模型來說并不容易。
數據處理
由于擴散模型生成的圖像通常包含殘留噪聲和各種細節不準確性,因此增強上采樣模型的魯棒性以解決這些問題變得至關重要。本文采用類似于 Real-ESRGAN 的方法,對訓練數據中的低分辨率輸入圖像進行各種降質處理。
在處理分辨率高于 512 的訓練圖像時,有兩種替代方法:直接執行隨機裁剪,或將較短的一側調整為 512 后再進行隨機裁剪。直接裁剪方法保留了高分辨率圖像中的高頻特征,而調整后裁剪方法則避免了頻繁裁剪出單色背景區域,這對模型的收斂性不利。因此,在實踐中,本文隨機選擇這兩種處理方法中的一種來裁剪訓練圖像。
訓練設置
在訓練期間,本文設置塊大小為 128,patch 大小為 4,這意味著每個訓練圖像被分成 4×4 個塊,每個塊有 32×32 個 patch。本文采用 EDM 框架進行訓練,并將上采樣因子設置為 4 倍。由于上采樣任務更關注圖像的高頻細節,本文調整了訓練噪聲分布的均值和標準差為-1.0 和 1.4。為了解決訓練期間的溢出問題,本文使用了具有更大數值范圍的 BF16 格式。本文的 CLIP 模型是一個在 Datacomp 數據集上預訓練的 ViT-L/16。由于 CLIP 只能處理 224×224 分辨率的圖像,本文首先將低分辨率圖像調整為 224×224,然后將其輸入到 CLIP 中。
機器評價
在這部分中,本文對 Inf-DiT 在超高分辨率圖像生成任務上與最先進方法進行定量比較。基線包括兩大類高分辨率生成:1. 直接高分辨率圖像生成,包括 SDXL 的直接推理、MultiDiffusion、ScaleCrafte 等;2. 基于超分辨率技術的高分辨率圖像生成,包括 BSRGAN、DemoFusion 等。本文采用 FID(Fréchet Inception Distance)來評估超高分辨率生成的質量,這在圖像生成任務中廣泛用于評估圖像的感知質量。為了進一步驗證本文模型的超分辨率能力,本文還將其與經典的超分辨率模型在典型超分辨率任務上進行了基準測試。
超高分辨率生成本文使用 HPDv2 的測試集進行評估。它包含 3200 個提示,并分為四個類別:“動畫”、“概念藝術”、“繪畫”和“照片”。這允許對模型在各個領域和風格中的生成能力進行全面評估。本文在兩個分辨率上進行測試:2048x2048 和 4096x4096。對于基于超分辨率的模型,本文首先使用 SDXL 生成 1024x1024 分辨率的圖像,然后在沒有文本的情況下對其進行上采樣。本文使用 BSRGAN 的 2× 和 4× 版本分別進行 2048x2048 和 4096x4096 的生成。盡管 Inf-DiT 是在 4× 上采樣的設置下進行訓練的,但本文發現它在較低的上采樣倍數下也能很好地泛化。因此,對于 2048x2048 的生成,本文直接將 LR 圖像從 1024x1024 調整為 2048x2048,并與噪聲輸入連接起來。本文從 LAION-5B 中隨機選擇了 3200 張 2048x2048 和 4096x4096 的圖像作為真實圖像的分布。
如下表 2 所示,本文的模型在所有指標上均達到了最先進水平。這表明,作為一個超分辨率模型,本文的模型不僅在任意尺度上表現出色,而且在最大限度地保留全局和詳細信息的同時,還能恢復與原始圖像非常接近的結果。
人類評價
為了進一步評估 Inf-DiT 并更準確地從人類視角反映其生成質量,本文進行了人類評價。比較設置與上節中相同,不過本文排除了 MultiDiffusion 和 Direct Inference 因其非競爭性的結果。對于每個類別,本文隨機選擇了十組比較集,每組集合包含了四個模型的輸出,共計 40 組形成了人類評價數據集。為了保證公平性,在每個比較集中本文對模型輸出的順序進行了隨機化。人類評估者被要求根據三個標準評估模型:細節真實性、全局連貫性和與原始低分辨率輸入的一致性。每位評估者平均收到 20 組圖像。在每個集合中,評估者需要根據三個標準對由四個模型生成的圖像進行從高到低的排名。
本文最終收集了 3600 組比較。如下圖 7 所示,本文的模型在所有三個標準中均優于其他三種方法。值得注意的是,其他三種模型中的每一種在至少一個評估標準上排名相對較低,而 Inf-DiT 在所有三個標準上的得分都最高:細節真實性、全局連貫性和與低分辨率輸入的一致性。這表明本文的模型是唯一能夠在高分辨率生成和超分辨率任務中同時表現出色的模型。
迭代上采樣
由于本文的模型可以對任意分辨率的圖像進行上采樣,測試模型是否能夠迭代上采樣自身生成的圖像是一個自然的想法。在這項研究中,本文在一張 322 分辨率的圖像上進行實驗,通過三次迭代上采樣,將其生成一張 2048x2048 分辨率的圖像,即 64 倍放大。下圖 8 展示了這個過程的兩個案例。在第一個案例中,模型成功地在三個階段的上采樣后生成了一張高分辨率圖像。它在不同分辨率的上采樣中生成了不同頻率的細節:臉部輪廓、眼球形狀和個別睫毛。然而,模型很難糾正在早期階段生成的不準確性,導致錯誤的積累。第二個樣本展示了這個問題。本文將這個問題留給未來的工作。
消融研究
相關工作
擴散圖像生成
擴散模型已經成為圖像生成領域的焦點,近年來取得了一系列突破性進展。最初于 2015 年引入,并通過諸如 DDPM 和 DDIM 等工作進一步發展,這些模型利用隨機擴散過程,概念化為馬爾可夫鏈,將簡單的先驗分布(如高斯噪聲)轉化為復雜的數據分布。這一方法在生成的圖像質量和多樣性方面取得了令人印象深刻的成果。
近期的增強顯著提升了擴散模型的生成能力。CDM 創建了一個級聯生成 pipeline,其中包括多階段的超分辨率模型,可應用于大型預訓練模型。引入潛在擴散模型(LDMs)代表了一個重要的擴展,它結合了潛在空間,提升了效率和可擴展性。除此之外,網絡架構的優化也取得了顯著進展。擴散 Transformer(DiT)的出現取代了 U-Net,使用 ViT 進行噪聲預測。
圖像超分辨率
這里 D 和 F 分別表示退化過程和超分辨率模型。δ和θ 代表參數。
近年來,盲目 SR 一直是主要關注的焦點:其中退化過程是未知的且可學習的。這一視角導致了有效的建模技術的發展,例如 BSRGAN 和 Real-ESRGAN。最近,基于擴散的 SR 方法取得了令人興奮的結果。這些工作專注于對預訓練的文本到圖像擴散模型進行微調,以利用其優秀的生成能力。具體來說,DiffBir 在預訓練的穩定擴散模型上使用了 ControlNet,而 PASD 通過執行像素感知的交叉注意力來增強它。這兩種方法在固定分辨率超分辨率方面取得了相當大的成功,但不能直接用于更高的分辨率。
超高分辨率圖像上采樣器
目前,圖像生成方法在生成超高分辨率圖像方面存在著內存限制和訓練效率問題。在這種情況下,MultiDiffusion 和 Mixture of Diffusers 將多個擴散生成過程綁定在一起,通過將圖像劃分為重疊的塊,分別處理每個塊,然后將它們拼接在一起,旨在保持塊之間的連續性。然而,由于它們僅使用局部加權平均進行聚合,導致了交互效率低下,使得很難確保圖像的全局一致性。
鑒于這一問題,DemoFusion 和 ScaleCrafter 采用了擴張策略,包括擴張采樣和擴張卷積核,旨在獲取更多的全局信息。這些方法確實在全局語義水平上取得了改進,而無需額外的訓練。然而,訓練和生成之間的巨大差異導致這些方法很容易產生不合邏輯的圖像。
Inf-DiT 能夠對任何生成模型生成的圖像執行上采樣,在這里展示了更多的情況。
結論
在這項工作中,本文觀察到生成超高分辨率圖像的主要障礙是模型隱藏狀態占用了大量內存。基于此,本文提出了一種單向塊注意力機制(UniBA),它可以通過在塊之間進行批量生成來降低空間復雜度。利用 UniBA,本文訓練了 Inf-DiT,這是一種 4 倍內存效率的圖像上采樣器,在生成和超分辨率任務中均取得了最先進的性能。
本文轉自 AI生成未來 ,作者:Zhuoyi Yang等
