擴(kuò)散語(yǔ)言模型九倍推理加速!上海交大:KV Cache并非自回歸模型的專屬技巧
首個(gè)用于加速擴(kuò)散式大語(yǔ)言模型(diffusion-based Large Language Models, 簡(jiǎn)稱 dLLMs)推理過(guò)程的免訓(xùn)練方法。
上海交通大學(xué)EPIC Lab團(tuán)隊(duì)提出了一種無(wú)需訓(xùn)練、即插即用的高效推理緩存機(jī)制:dLLM-Cache。
其核心思想在于,在一個(gè)多步去噪過(guò)程中,復(fù)用相鄰時(shí)間步上變化較小的特征,僅更新那些變化較大的特征,從而實(shí)現(xiàn)了計(jì)算量的大幅降低,并保持了原有的生成質(zhì)量。
圖1 不同dLLMs使用dLLM–Cache和不使用dLLM–Cache在速度和質(zhì)量上的對(duì)比
dLLM-Cache具有幾個(gè)重要的亮點(diǎn):
1. 訓(xùn)練無(wú)關(guān),即插即用。dLLM-Cache完全在推理過(guò)程中工作,無(wú)需修改模型參數(shù)或重訓(xùn)練。dLLM-Cache可以在完全不損失模型輸出質(zhì)量的前提下,帶來(lái)最高9.1倍的推理速度提升 。
2. 通用于主流dLLM架構(gòu),如LLaDA、Dream以及LLaDA-V、MMaDA、Dimple等多模態(tài)模型。
3. 在推理過(guò)程中,首次識(shí)別出了prompt部分的Transformer中間層特征(Key、Value、Attention output、FFN output)長(zhǎng)期穩(wěn)定,而response部分僅有一小部分tokens的特征變化較大,為緩存特征并后續(xù)復(fù)用提供了理論基礎(chǔ)。
4. 獨(dú)創(chuàng)了以V-verify機(jī)制為核心的選擇更新策略。以Value向量的變化為選擇基準(zhǔn),成功識(shí)別出了response部分變化較大的那些tokens,通過(guò)僅更新這些特征,摒棄了高達(dá)75%的冗余計(jì)算。
本論文共同第一作者劉知遠(yuǎn)和楊奕存是哈爾濱工業(yè)大學(xué)2022級(jí)本科生,目前在上海交通大學(xué)EPIC Lab進(jìn)行科研實(shí)習(xí),師從張林峰助理教授,主要研究方向?yàn)楦咝疃葘W(xué)習(xí),此前曾在CVPR2025上收獲滿分論文。
接下來(lái),我們一起來(lái)看看該研究的細(xì)節(jié)。
研究動(dòng)機(jī)
基于擴(kuò)散的大語(yǔ)言模型正成為語(yǔ)言生成領(lǐng)域最受關(guān)注的新范式之一。隨著模型架構(gòu)的發(fā)展、去噪算法的優(yōu)化以及Masked Diffusion在語(yǔ)言建模中逐步展現(xiàn)出與自回歸模型不同的建模能力,這類模型正在逐步成為挑戰(zhàn) GPT 等主流模型的重要力量。
以LLaDA、Dream為代表的擴(kuò)散語(yǔ)言模型,基于迭代去噪的生成過(guò)程,不再依賴嚴(yán)格的自回歸因果結(jié)構(gòu),天然支持雙向建模、全局依賴和反向推理等能力,已經(jīng)在“逆轉(zhuǎn)詛咒”、數(shù)學(xué)推理等任務(wù)上展現(xiàn)出領(lǐng)先性能。
然而,這種范式的優(yōu)勢(shì)也伴隨著巨大的代價(jià)。為了確保生成的質(zhì)量,dLLMs在推理過(guò)程中通常需要執(zhí)行長(zhǎng)達(dá)數(shù)百步的去噪迭代,每一步都需重新計(jì)算attention、FFN等所有層的特征,計(jì)算量相當(dāng)于多次完整前向傳播。這為dLLMs的推理效率帶來(lái)了嚴(yán)重的瓶頸,制約了其實(shí)際部署。更重要的是,主流的加速手段如用于自回歸模型的KV Cache,由于不兼容雙向注意力架構(gòu),在dLLMs中完全失效。
與傳統(tǒng)的自回歸語(yǔ)言模型不同,dLLMs不再依賴順序生成下一個(gè)token,而是采用隨機(jī)遮蔽(mask) + 逐步還原的方式建模token分布,這種機(jī)制使得模型具備天然的雙向建模能力,理論上能夠更好地處理逆向邏輯、長(zhǎng)距離依賴等任務(wù)。
LLaDA 等模型已經(jīng)在多個(gè)基準(zhǔn)任務(wù)中超越主流ARMs,尤其在“逆轉(zhuǎn)詛咒”上明顯勝出。
然而,這種擴(kuò)散式推理帶來(lái)一個(gè)嚴(yán)重的挑戰(zhàn):為了確保生成質(zhì)量,dLLMs通常需要上百步的去噪迭代,每一步都需全量計(jì)算Attention、FFN等模塊,導(dǎo)致其推理速度相比ARMs慢一個(gè)數(shù)量級(jí),落地成本高。同時(shí),ARMs 通用的加速方法如KV-Cache因dLLMs的雙向注意力設(shè)計(jì)而無(wú)法兼容。這些造成了dLLMs在推理時(shí)既慢又缺乏加速手段的現(xiàn)象。這正是 dLLM-Cache所要破解的核心問(wèn)題。
方法簡(jiǎn)介
本文作者仔細(xì)研究了dLLMs推理的中間特征變化過(guò)程,發(fā)現(xiàn)如下關(guān)鍵現(xiàn)象:
圖2 dLLM中兩個(gè)相鄰去噪步驟之間的Key、Value、Attention Output和FFN Output的余弦相似度
Prompt tokens的特征在整個(gè)去噪過(guò)程中基本保持穩(wěn)定,每一步都重新計(jì)算這些特征是完全不必要且浪費(fèi)計(jì)算資源的;
Response tokens多數(shù)變化很小,僅少部分變化劇烈,全量計(jì)算所有response tokens存在冗余。
由此,問(wèn)題轉(zhuǎn)化為了如何高效識(shí)別出這些變化劇烈的response tokens。
圖3 Response tokens的K或V變化與其他特征變化的相關(guān)性
本文作者首創(chuàng)性得提出了V-verify機(jī)制。它的提出源于另一項(xiàng)重要的發(fā)現(xiàn):作者量化了response tokens的底層特征(Key, Value向量)的變化與其上層復(fù)雜特征(Attention Output, FFN Output)的變化之間的關(guān)系,結(jié)果顯示它們存在著極強(qiáng)的正相關(guān)性,皮爾遜相關(guān)系數(shù)最高可達(dá)0.944。
這意味著,一個(gè)token底層的Value向量是否發(fā)生變化,是其整體狀態(tài)是否發(fā)生改變的一個(gè)極佳的、且計(jì)算成本極低的“指示器”。
基于以上這些關(guān)鍵的觀察,本文作者提出了dLLM-Cache ,具體的框架設(shè)計(jì)如下:
圖4 dLLM-Cache方法整體pipeline
Prompt緩存:長(zhǎng)間隔重用
對(duì)于prompt部分,作者設(shè)計(jì)了長(zhǎng)間隔Prompt緩存,每隔Kp步(在實(shí)驗(yàn)中一般設(shè)置為100)更新一次prompt的Key、Value、Attention Output、FFN Output,其余步驟全部復(fù)用先前結(jié)果。這樣避免了對(duì)穩(wěn)定不變的特征的重復(fù)計(jì)算,大幅減少了計(jì)算量。
Response緩存:自適應(yīng)部分更新
對(duì)生成目標(biāo)response區(qū)域,由于response tokens的特征并不是一直保持穩(wěn)定不變的,作者設(shè)計(jì)了較短間隔的Response緩存,每隔Kr步(在實(shí)驗(yàn)中一般設(shè)置為8左右)全量更新一次response的Key、Value、Attention Output、FFN Output,在其余的步驟,作者提出了基于V-verify的自適應(yīng)緩存策略:
- 在每個(gè)去噪步驟,首先計(jì)算所有response tokens最新的Value向量。
- 然后,通過(guò)計(jì)算新Value向量與緩存中舊Value向量的余弦相似度,將余弦相似度作為每個(gè)response tokens的一個(gè)“變化分”。
- 只選出“變化分”最高(即相似度最低)的極少數(shù)tokens(例如,變化最劇烈的25%),將它們標(biāo)記為“待更新” 。
- 最后,模型只對(duì)這些被標(biāo)記的“待更新”tokens,進(jìn)行完整的特征重計(jì)算。而其余75%的“穩(wěn)定”tokens,則繼續(xù)高效地從緩存中復(fù)用其特征。
通過(guò)這種“長(zhǎng)間隔”與“自適應(yīng)”相結(jié)合的緩存策略,dLLM-Cache在Transformer的每一層都實(shí)現(xiàn)了計(jì)算量的極致優(yōu)化,且整個(gè)過(guò)程無(wú)需任何額外訓(xùn)練,做到了真正的即插即用。
3 實(shí)驗(yàn)結(jié)果
本文在 LLaDA 8B和Dream 7B兩大代表性的開(kāi)源dLLM的基礎(chǔ)版與指令微調(diào)版上,針對(duì)數(shù)學(xué)與科學(xué)、通用任務(wù)、代碼生成三大領(lǐng)域的8個(gè)主流基準(zhǔn)測(cè)試,對(duì)dLLM-Cache的有效性進(jìn)行了嚴(yán)苛的檢驗(yàn) 。評(píng)估維度不僅包括推理速度(TPS)和計(jì)算效率(FLOPs),更核心的是模型性能得分(Score),以確保加速不是以犧牲模型能力為代價(jià)。
本文在LLaDA 8B的基礎(chǔ)版和指令微調(diào)版上都部署了dLLM-Cache,下圖的實(shí)驗(yàn)結(jié)果充分展示了其強(qiáng)大的加速能力和卓越的生成質(zhì)量保持。在幾乎所有的基準(zhǔn)測(cè)試中,都達(dá)到了5倍以上的加速效果,且在絕大部分情況下,生成質(zhì)量都沒(méi)有降低,甚至有輕微的提升。特別是當(dāng)面對(duì)LongBench任務(wù)時(shí),prompt的穩(wěn)定性帶來(lái)了更顯著的加速效果,在HotpotQA上實(shí)現(xiàn)了高達(dá)9.1倍的無(wú)損加速。
圖5 dLLM-Cache在LLaDA模型上的效果
為了進(jìn)一步證明dLLM-Cache的通用性和魯棒性,作者將其無(wú)縫遷移至另一款架構(gòu)略有不同的dLLM——Dream 7B上。下圖的實(shí)驗(yàn)結(jié)果再次印證了dLLM-Cache方法的有效性,充分說(shuō)明了其通用于主流dLLM架構(gòu)。
圖6 dLLM-Cache在Dream模型上的效果
作者還將dLLM和主流的基于ARM的LLM進(jìn)行了對(duì)比,下圖展示了LLaDA 8B與LLaMA3 8B在GSM8K任務(wù)上的比較。結(jié)果顯示,原始的LLaDA在準(zhǔn)確率上以近20個(gè)點(diǎn)的巨大優(yōu)勢(shì)領(lǐng)先于LLaMA3,但在推理速度上卻遠(yuǎn)不及。然而,在使用了本文的dLLM-Cache之后,LLaDA的推理速度獲得了超過(guò)3.3倍的提升,首次超過(guò)了LLaMA3的推理速度。這一結(jié)果有力地證明,本文提出的dLLM-Cache能夠讓dLLMs在保持其顯著準(zhǔn)確率優(yōu)勢(shì)的同時(shí),獲得與ARMs相當(dāng)競(jìng)爭(zhēng)力的推理速度。
圖7 使用dLLM-Cache的dLLM vs 使用KV-Cache的ARM
論文鏈接: https://github.com/maomaocun/dLLM-cache/blob/main/asset/paper.pdf
代碼已開(kāi)源: https://github.com/maomaocun/dLLM-Cache