單GPU搞定高清長(zhǎng)視頻生成,效率×10!引入Mamba機(jī)制突破DiT瓶頸 | 普林斯頓&Meta
視頻生成模型太貴太慢怎么辦?
普林斯頓大學(xué)和Meta聯(lián)合推出的新框架LinGen,以MATE線性復(fù)雜度塊取代傳統(tǒng)自注意力,將視頻生成從像素?cái)?shù)的平方復(fù)雜度壓到線性復(fù)雜度,使單張GPU就能在分鐘級(jí)長(zhǎng)度下生成高質(zhì)量視頻,大幅提高了模型的可擴(kuò)展性和生成效率。
實(shí)驗(yàn)結(jié)果表明,LinGen在視頻質(zhì)量上優(yōu)于DiT(勝率達(dá)75.6%),并且最高可減少15×(11.5×)FLOPs(延遲)。此外,自動(dòng)指標(biāo)和人工評(píng)估均顯示,LinGen-4B在視頻質(zhì)量上與最先進(jìn)模型相當(dāng)(分別以50.5%、52.1%、49.1%的勝率優(yōu)于Gen-3、Luma Labs和Kling)。
方法:線性復(fù)雜度的MATE模塊
LinGen維持Diffusion Transformer(DiT)中的其他結(jié)構(gòu)不變,而將其計(jì)算瓶頸——平方復(fù)雜度的自注意力模塊替換為線性復(fù)雜度的MATE模塊,它由MA分支和TE分支組成。
其中,MA分支包含一個(gè)雙向的Mamba2模塊。
Mamba2作為State Space Model(SSM)的變體,善于處理超長(zhǎng)的token序列,同時(shí)又對(duì)硬件非常友好,可以使用attention的各種硬件加速核,如xformers,F(xiàn)lashAttention等。但是Mamba系列模型在語(yǔ)言任務(wù)上的優(yōu)秀表現(xiàn)難以直接遷移到大型視覺(jué)任務(wù)上,生成的高分辨率視頻往往一致性很差、質(zhì)量不高。
一些特殊的scan方法嘗試解決這一問(wèn)題,如Zigzag scan,Hilbert scan,但它們都要求對(duì)序列做復(fù)雜的順序變換,而這個(gè)操作對(duì)硬件極其不友好。在處理高分辨率、長(zhǎng)視頻時(shí),會(huì)帶來(lái)顯著的額外延遲。
針對(duì)于此,LinGen提出Rotary Major Scan(RMS),相鄰層中四種scan方式交替切換。
以上圖的方式為例,W,H和T分別在展開時(shí)有第一、第二和第三優(yōu)先級(jí),通過(guò)交換展開的優(yōu)先級(jí),就可以實(shí)現(xiàn)不同的scan方式。
相比于已有方法,該方法最大的好處是對(duì)硬件非常友好、可以通過(guò)簡(jiǎn)單的tensor reshaping實(shí)現(xiàn),因此也幾乎沒(méi)有額外開銷,同時(shí)還把scan后原相鄰token的平均距離降到了和已有特殊scan方式相同的水平。
然而,所有這些特殊的scan方式仍然不足以完全解決Mamba的臨近信息丟失問(wèn)題,因?yàn)樵谀P偷娜我庖粚又校粫?huì)有一種scan方式被應(yīng)用,如果不考慮跨層交流,大量臨近信息在單層中依舊有損失。
針對(duì)于此,LinGen在TE分支中應(yīng)用了TEmporal Swin Attention(TESA):它是一種特殊的3D window attention,窗口范圍在不同層中會(huì)滑動(dòng),每一個(gè)窗口都很小,并且窗口大小不隨視頻分辨率和長(zhǎng)度(即3D tensor的大小)的變化而變化。
這是因?yàn)門ESA僅用來(lái)處理最臨近的信息,這一固定的窗口大小也使得TESA實(shí)現(xiàn)了相對(duì)3D tensor中token數(shù)的線性復(fù)雜度。
作為額外的補(bǔ)充,LinGen還在MA分支中引入了review tokens。它被用以增強(qiáng)視頻中極長(zhǎng)程的一致性,例如在60秒視頻的結(jié)尾復(fù)現(xiàn)視頻前幾秒消失的人。它把待處理video tensor的概覽提前寫入Mamba的hidden state memory中,為后續(xù)的視頻處理提供幫助。
評(píng)估:遠(yuǎn)超基線,對(duì)標(biāo)SOTA
從人類評(píng)測(cè)和模型自動(dòng)評(píng)測(cè)兩個(gè)角度將LinGen與已有的先進(jìn)視頻生成模型、以及DiT baseline進(jìn)行比較。
無(wú)論是人類評(píng)測(cè)的結(jié)果,還是在VBench上的自動(dòng)評(píng)測(cè)的結(jié)果,都顯示LinGen與先進(jìn)的商業(yè)模型Kling、Runway Gen-3生成的視頻質(zhì)量接近,并且遠(yuǎn)勝于OpenSora v1.2。
可以看到,在FLOPs方面,當(dāng)生成17秒、34秒和68秒長(zhǎng)度的512p視頻時(shí),LinGen-4B相對(duì)于DiT-4B分別實(shí)現(xiàn)了5×、8×和15×的加速;
在延遲方面,當(dāng)在單個(gè)H100上生成512p和768p的17秒視頻時(shí),LinGen-4B相對(duì)于DiT-4B分別實(shí)現(xiàn)了2.0×和3.6×的加速;
當(dāng)生成17秒、34秒和68秒長(zhǎng)度的512p視頻時(shí),LinGen-4B相對(duì)于DiT-4B分別實(shí)現(xiàn)了2.0×、3.9×和11.5×的延遲加速。
這說(shuō)明LinGen具有線性復(fù)雜度,可以在單卡上實(shí)現(xiàn)分鐘級(jí)視頻生成,速度遠(yuǎn)快于DiT。與相同大小的DiT相比,LinGen可實(shí)現(xiàn)推理速度11倍以上的提升。
另外,LinGen和相同大小、在相同數(shù)據(jù)集上以相同training recipe訓(xùn)練的DiT baseline相比,在視頻質(zhì)量和文字-視頻一致性上取得全面領(lǐng)先。相比起DiT,LinGen可以更快地適應(yīng)更長(zhǎng)的token序列。
通常認(rèn)為自注意力模塊的線性替代是對(duì)完整自注意力的近似,雖然在速度上有顯著優(yōu)勢(shì),但在模型性能上往往略遜一籌,而LinGen打破了這個(gè)慣有的看法。
在整個(gè)預(yù)訓(xùn)練過(guò)程中,模型從低分辨率圖像生成開始,學(xué)習(xí)低分辨率視頻生成,再不斷增加所生成視頻的分辨率和長(zhǎng)度,所處理的token數(shù)增長(zhǎng)了上千倍。
而在從少token數(shù)的任務(wù)遷移到多token數(shù)的任務(wù)時(shí),LinGen的適應(yīng)性遠(yuǎn)強(qiáng)于DiT(a圖中是從256x256分辨率視頻生成遷移到512x512分辨率視頻生成任務(wù)時(shí)的loss curve),這可能是受益于Mamba對(duì)于長(zhǎng)序列的高適應(yīng)性,這一特征已經(jīng)在語(yǔ)言任務(wù)上被觀察到。
為了進(jìn)一步驗(yàn)證這里推理,選取這一預(yù)訓(xùn)練階段的早期checkpoint進(jìn)行比較,發(fā)現(xiàn)LinGen比DiT的win rate優(yōu)勢(shì)變得更加顯著。這暗示了雖然LinGen在任務(wù)遷移的早期能大幅領(lǐng)先DiT,但是這種優(yōu)勢(shì)隨著預(yù)訓(xùn)練的進(jìn)行,在不斷減小。
盡管如此,在訓(xùn)練資源有限的情況下,LinGen在預(yù)訓(xùn)練的極長(zhǎng)一段時(shí)間內(nèi)仍舊能對(duì)DiT保持優(yōu)勢(shì)。
項(xiàng)目主頁(yè):https://lineargen.github.io/
論文鏈接:https://arxiv.org/abs/2412.09856
項(xiàng)目代碼:https://github.com/jha-lab/LinGen