圖像生成模型王牌——Diffusion Transformers系列工作梳理
圖像生成模型是目前業(yè)內(nèi)研究的焦點(diǎn),而目前諸如Sora等前沿生成模型,其所基于的主體架構(gòu)都是Diffusion Transformers(DiT)。Diffusion Transformers(DiT)是論文Scalable Diffusion Models with Transformers(ICCV 2023)中提出的,是擴(kuò)散模型和Transformer的結(jié)合,也是Sora使用的底層生成模型架構(gòu),將Diffusion Transformers從圖像生成擴(kuò)展到了視頻生成。這篇文章給大家總結(jié)了目前主要的幾個(gè)DiT模型結(jié)構(gòu),帶大家梳理DiT系列模型的核心。?
更加完整的多模態(tài)生成模型技術(shù)總結(jié),可以加入我的星球【圓圓的算法筆記】,獲取Sora底層原理解析專欄。
1.DiT
在之前的圖像生成擴(kuò)散模型中,底層的網(wǎng)絡(luò)結(jié)構(gòu)一般都是U-Net。而本文基于Vision Transformer(ViT)中的Transformer圖像分類模型結(jié)構(gòu),替代擴(kuò)散模型中的U-Net,得到DiT模型,實(shí)現(xiàn)了更優(yōu)的生成效果。
在輸入部分,基本采用了和ViT相同的方法。對(duì)輸入的圖像分成多個(gè)patch,并轉(zhuǎn)換成一個(gè)token序列,每個(gè)token拼接上相應(yīng)的position embedding。這個(gè)底層的embedding序列作為后續(xù)DiT模塊的輸入。
在擴(kuò)散模型中,Transformer除了像ViT那樣輸入圖像patch token序列,往往還要輸入一些額外的信息,包括擴(kuò)散模型中當(dāng)前的生成時(shí)間步、文本信息的輸入等,如何將這些信息輸入到DiT中,文中嘗試了幾種方案。最簡(jiǎn)單的方法是將這些額外的embedding直接拼接到原始的序列上。第二種是將外部的embedding單獨(dú)拼接成一個(gè)序列,和原始的圖像patch序列額外做一個(gè)cross attention。第三種方法是修改Transformer中的layer normalization模塊,將其替換成adaptive layer normalization,LN的均值和方差由外部embedding的加和生成。第四種是在第三種的基礎(chǔ)上,引入了基于外部embedding生成的縮放因子,對(duì)multi-head attention的輸出進(jìn)行縮放。
在經(jīng)過多層的DiT模型后,需要將預(yù)測(cè)的噪聲結(jié)果還原出來,這里使用一個(gè)MLP作為Decoder,將DiT生成的結(jié)果映射到噪聲預(yù)測(cè)結(jié)果。
上述就是DiT的整體結(jié)構(gòu),主要還是Vision Transformer。用這個(gè)DiT結(jié)構(gòu),替代擴(kuò)散模型中的去噪模塊,也就是噪聲預(yù)測(cè)網(wǎng)絡(luò),就是DiT模型
從實(shí)驗(yàn)對(duì)比中可以看出,DiT的生成效果是超過基于U-Net等之前的SOTA模型的。
2.U-ViT
U-ViT是另一個(gè)基于ViT的擴(kuò)散模型網(wǎng)絡(luò)。U-ViT也是將擴(kuò)散模型中的噪聲預(yù)測(cè)網(wǎng)絡(luò)替換成Transformer結(jié)構(gòu),并且借鑒了U-Net等傳統(tǒng)CV模型中的殘差網(wǎng)絡(luò)思路,每一層的輸出都會(huì)通過龍skip connection加到更深層的網(wǎng)絡(luò)中。此外,文中對(duì)一些模型結(jié)構(gòu)也進(jìn)行了嘗試,包括殘差網(wǎng)絡(luò)怎么加,是直接拼接到深層+MLP還是add到生成;擴(kuò)散步驟embedding怎么加入到U-ViT中;以及Transformer之后的卷積網(wǎng)絡(luò)怎么加。
3.MDT
MDT發(fā)表于論文Masked diffusion transformer is a strong image synthesizer(ICCV 2023),在DiT的基礎(chǔ)上,引入了mask latent modeling,進(jìn)一步提升了DiT的收斂速度和生成效果。
文中分析發(fā)現(xiàn),DiT在學(xué)習(xí)過程中,并不能很好的學(xué)習(xí)各個(gè)語義單元之間的關(guān)系。為了解決這個(gè)問題,MDT引入了一個(gè)重構(gòu)任務(wù),對(duì)輸入的圖像的部分patch進(jìn)行mask,然后使用一個(gè)Transformer模型在生成過程中,對(duì)這部分被mask掉的patch進(jìn)行還原。在擴(kuò)散模型中,每一層MDT輸入被mask掉一部分的token序列,只根據(jù)這部分序列進(jìn)行噪聲預(yù)測(cè)。同時(shí),使用一個(gè)Transformer網(wǎng)絡(luò)來還原被mask掉的部分。通過這種方式,讓模型在學(xué)習(xí)過程中強(qiáng)行學(xué)習(xí)patch之間的關(guān)系。同時(shí)通過position embedding的引入提升對(duì)mask token的還原能力。
由于在生成階段,decoder在處理token的時(shí)候都是沒有mask的,訓(xùn)練的時(shí)候是mask的,這種不一致會(huì)影響效果。因此文中采用side-interpolater,對(duì)被mask掉的部分使用side-interpolater的預(yù)測(cè)結(jié)果,融合上沒被mask的結(jié)果,保證訓(xùn)練和預(yù)測(cè)階段decoder的輸入都是沒有mask掉的。
4.Diffit
Diffit是英偉達(dá)發(fā)表于論文Diffit: Diffusion vision transformers for image generation(2023)中的一種方法,也是Diffusion Transformer的一個(gè)變體,在模型結(jié)構(gòu)上進(jìn)行了改進(jìn)。整體的結(jié)構(gòu)類似于U-Net和Transformer的結(jié)合,通過增加downsample和upsample實(shí)現(xiàn)層次性的建模。
Diffit在引入擴(kuò)散步驟embedding的時(shí)候,采用了一種Time-dependent Self-Attention的方式,即將步驟embedding直接加入到輸入token序列上,讓self-attention在計(jì)算的過程中就考慮到擴(kuò)散步驟的信息。在模型結(jié)構(gòu)上,采用U-Shape的形式,Encoder部分每一層Transformer后做downsample,來提取不同分辨率下的圖像信息,Decoder部分再逐漸upsample。
本文轉(zhuǎn)載自 ??圓圓的算法筆記??,作者: Fareise
