DeepSeek的多頭潛在注意力(MLA)和及其11種KV-Cache技巧演進(jìn)大總結(jié) 原創(chuàng)
引言
本文將探討KV-Cache如何通過在內(nèi)存使用和計(jì)算時(shí)間之間進(jìn)行巧妙的權(quán)衡,使像ChatGPT和DeepSeek這樣的語言模型在生成文本時(shí)更快。
總結(jié)11篇最近的研究論文,歸納三大類:token選擇、后處理壓縮技術(shù)和架構(gòu)重新設(shè)計(jì)。包括DeepSeek的多頭潛在注意力(MLA),這些論文在這一基本思想的基礎(chǔ)上,進(jìn)一步提高了大型語言模型(LLM)推理的時(shí)間效率。
一、思考
為什么文本生成如此緩慢
讓我們從一個(gè)簡單的類比開始。想象你在寫一個(gè)故事,每寫一個(gè)新詞,你都需要重新閱讀到目前為止的整個(gè)故事以保持一致性。故事越長,重新閱讀的時(shí)間就越長。這正是大型語言模型在文本生成時(shí)所面臨的問題。
自注意力的基本構(gòu)建塊
現(xiàn)代語言模型的核心是一種稱為自注意力的機(jī)制。對于一個(gè)由n個(gè)標(biāo)記(大致對應(yīng)單詞)組成的序列,每個(gè)標(biāo)記都需要“查看”或“關(guān)注”所有其他標(biāo)記以理解上下文。
這種查看一切的過程的計(jì)算成本隨著序列長度的增長而增長:
- 對于n個(gè)標(biāo)記,每個(gè)標(biāo)記都需要查看所有nn個(gè)標(biāo)記
- 這意味著成本與n x n = n2成正比
- 用數(shù)學(xué)符號表示,我們將其寫為O(n2)的復(fù)雜度
真正的問題:一次生成一個(gè)標(biāo)記
當(dāng)語言模型生成文本時(shí),它一次生成一個(gè)標(biāo)記,這就是事情變得計(jì)算密集的地方:
- 第一個(gè)標(biāo)記:查看1個(gè)標(biāo)記(成本:O(12))
- 第二個(gè)標(biāo)記:查看2個(gè)標(biāo)記(成本:O(22))
- 第三個(gè)標(biāo)記:查看3個(gè)標(biāo)記(成本:O(32))
- 以此類推,直到第n個(gè)標(biāo)記:查看n個(gè)標(biāo)記(成本:O(n2))
如果我們將生成長度為的序列的所有這些成本加起來,我們得到:
這種O(n3)的成本意味著隨著文本的增長,生成時(shí)間會極其迅速地增長。例如,生成兩倍長的序列大約需要八倍的時(shí)間!顯然,我們需要一個(gè)更好的方法。
解決方案:鍵值(KV)緩存
KV 緩存背后的關(guān)鍵是,我們正在做大量冗余工作。在生成每個(gè)新標(biāo)記時(shí),我們會重新計(jì)算之前已經(jīng)處理過的所有先前標(biāo)記。讓我們看看如何解決這個(gè)問題。
什么是鍵值緩存?
可以將 KV 緩存想象成一個(gè)智能記事本,我們會在第一次看到每個(gè) token 時(shí)記下有關(guān)它的重要信息。對于每個(gè) token,我們計(jì)算并存儲兩件事:
- 鍵(k):可以將其視為一種尋址機(jī)制——它有助于確定此標(biāo)記與未來標(biāo)記的相關(guān)性
- 值(v):可以將其視為當(dāng)此標(biāo)記被發(fā)現(xiàn)相關(guān)時(shí)實(shí)際使用的信息
從數(shù)學(xué)上,我們計(jì)算這些為:
- 鍵:k = xWk(其中是x標(biāo)記,Wk是一個(gè)學(xué)習(xí)到的變換)
- 值:v = xWv(其中Wv是另一個(gè)學(xué)習(xí)到的變換)
在生成一個(gè)新標(biāo)記時(shí),我們使用它的查詢(計(jì)算方式類似于鍵)通過將其與所有存儲的鍵進(jìn)行比較來在我們的緩存中找到相關(guān)信息。然后使用匹配的值來幫助生成標(biāo)記。
KV緩存如何加速
有了KV緩存,處理過程變得更加高效:
- 當(dāng)我們遇到一個(gè)新token時(shí),只需要計(jì)算它的key和value一次
- 對于所有后續(xù)的token,我們可以直接從緩存中查找這些預(yù)計(jì)算的值
- 這意味著每個(gè)新token只需要做少量新的計(jì)算,而不是重新做所有之前的計(jì)算
顯然有一個(gè)權(quán)衡:
- 我們需要更多的內(nèi)存來存儲所有的keys和values。對于一個(gè)具有:
L層
H注意力頭
序列長度n
key/value維度dk,總的內(nèi)存開銷為L x H x n x dk x 2值(這個(gè)2是因?yàn)樾枰鎯eys和values)。這會隨著序列長度n線性增長(O(n)),但對于大模型來說,常數(shù)因子可能非常大。
- 但作為回報(bào),我們將計(jì)算成本從O(n3)降低到O(n2)。
要理解為什么是O(n2),讓我們看一下每一步的成本:
- 第一步:處理一個(gè)token ->成本O(1)
- 第二步:處理一個(gè)新token + 查找1個(gè)緩存的token -> 成本O(2)
- 第三步:處理一個(gè)新token + 查找2個(gè)緩存的token -> 成本O(3)
- 依此類推...
將這些加起來:
O(1 + 2 + 3 + ... + n) = O(n2)
這相比O(n3)是一個(gè)顯著的改進(jìn)!雖然我們?nèi)匀恍枰霾榭此星懊娴膖okens的基礎(chǔ)工作<O(n2)>,但我們避免了每一步都進(jìn)行昂貴的重新計(jì)算。
內(nèi)存挑戰(zhàn):為什么我們需要更好的解決方案
雖然KV緩存是一個(gè)強(qiáng)大的優(yōu)化手段,但它伴隨著顯著的內(nèi)存開銷。讓我們通過一個(gè)具體的例子來看看,使用像Llama3 70B這樣的現(xiàn)代大語言模型:
- L = 80層
- H = 64注意力頭
- B = 8批量大小為8個(gè)序列
- dk= 128key/value維度
- 16位精度
處理一個(gè)批量(8個(gè)序列,每個(gè)序列1000個(gè)token)所需的內(nèi)存為:
L x H x B x n x dk x 2 x 2字節(jié)=80 x 64 x 8 x 1000 x 128 x 2 x 2字節(jié)=20.97GB
這種巨大的內(nèi)存使用帶來了幾個(gè)挑戰(zhàn):
- 隨著序列長度線性增長
- 與批量大小成倍增長,支持并行處理
- 限制了我們可以處理的最大上下文長度
- 限制了在內(nèi)存受限設(shè)備上的部署
這些挑戰(zhàn)激發(fā)了研究界的一波創(chuàng)新,導(dǎo)致了各種優(yōu)化KV緩存使用的技術(shù)。接下來,將探討這些前沿的解決方案。
二、如何改善傳統(tǒng)的KV緩存?
以下論文代表了KV緩存優(yōu)化的關(guān)鍵創(chuàng)新。我們將通過三大主要方法來探索它們:token選擇、后處理壓縮技術(shù)和架構(gòu)重設(shè)計(jì)。
2.1 Token 選擇和修剪方法(Token Selection and Pruning Approaches)
1) Heavy-Hitter Oracle (H2O)
H2O 引入了在KV緩存中識別和保留重要token的概念:
- 重型Token(Heavy-Hitter Tokens):H2O 識別在生成過程中具有最高累計(jì)注意力分?jǐn)?shù)的token,這些token遵循冪律分布。這些token對于模型的功能至關(guān)重要,因此在緩存中優(yōu)先處理。
- 動態(tài)次模撤銷(Dynamic Submodular Eviction):該方法將緩存管理問題框架化為一個(gè)優(yōu)化問題,目標(biāo)函數(shù)為次模函數(shù)F(S),用于量化token集合的S重要性:
確保每次最多只移除一個(gè)token。這個(gè)貪心算法在計(jì)算上高效,并在次模約束下保證接近最優(yōu)的性能。
- 結(jié)果:通過該方法,KV緩存大小減少了5倍,幾乎沒有精度損失,并且吞吐量提升了高達(dá)29倍。
2) StreamLLM
- 作者觀察到注意力匯聚(Attention Sinks)現(xiàn)象:解碼過程中,初始token充當(dāng)自然的注意力錨點(diǎn)。
- 如果沒有這些注意力匯聚的token,傳統(tǒng)窗口注意力方法的性能會下降。
- 基于這一觀察,他們引入了滾動緩存(Rolling Cache),它保留了初始token,并處理最近的上下文,從而實(shí)現(xiàn)了無限長度序列的處理。
- 他們還展示了這些匯聚token可以通過訓(xùn)練獲得,作為專用的注意力錨點(diǎn),從而減少對多個(gè)初始token的依賴。
3) Value-Aware Token Pruning (VATP)
VATP 擴(kuò)展了 H2O 的 token 重要性概念,考慮了注意力模式和價(jià)值向量的屬性:
性能與效率:
- 在16個(gè) LongBench 任務(wù)中,VATP 在12-14個(gè)任務(wù)中超越了 H2O 和 Scissorhands 等基準(zhǔn)。
- 在保持最小性能損失的情況下,實(shí)現(xiàn)了50%的有效壓縮。
- 引入的計(jì)算開銷幾乎可以忽略不計(jì),并且與 Scissorhands 集成時(shí)兼容 FlashAttention。
2.2 后處理壓縮技術(shù)(Post-hoc Compression Techniques)
這些方法壓縮或優(yōu)化KV緩存,同時(shí)保持標(biāo)準(zhǔn)的Transformer架構(gòu)。
4) Adaptive KV Compression (FastGen)
FastGen 通過觀察運(yùn)行時(shí)的注意力模式引入了自適應(yīng)壓縮:
自適應(yīng)壓縮策略:
5) 動態(tài)內(nèi)存壓縮(DMC)
DMC 引入了自適應(yīng)的 token 合并:
6)L2范數(shù)基礎(chǔ)的壓縮
本文提出了一個(gè)令人驚訝的觀察:緩存 KV 對的L2范數(shù)與注意力分?jǐn)?shù)之間存在明確的相關(guān)性,低L2范數(shù)的鍵嵌入通常會導(dǎo)致解碼時(shí)的高注意力分?jǐn)?shù)。因此,提出了一個(gè)簡單但有效的壓縮目標(biāo):
2.3 體系結(jié)構(gòu)重設(shè)計(jì)
這些方法改變了 Transformer 架構(gòu),以更高效地處理 KV 緩存,通常將壓縮直接集成到架構(gòu)中。
7) 多查詢注意力(MQA)
- 核心思想:MQA 通過共享單個(gè)鍵值頭跨所有查詢頭來減少 KV 緩存大小,替代傳統(tǒng)的多頭注意力(MHA):
K = XWK , V = XWV
其中K和V是共享的鍵和值投影。
- 優(yōu)點(diǎn):將 KV 緩存大小減少了H(注意力頭的數(shù)量),顯著降低了內(nèi)存帶寬開銷。
- 權(quán)衡:雖然 MQA 更快,但在需要多樣化注意力模式的任務(wù)中,通常會遭遇質(zhì)量下降。
8) 分組查詢注意力(GQA)
- 核心思想:GQA 在完全多頭注意力和 MQA 之間進(jìn)行插值,提供了推理速度和模型質(zhì)量之間的可擴(kuò)展權(quán)衡。它將查詢頭分為G組,每組共享一個(gè)單獨(dú)的鍵值頭:
9) 多頭潛在注意力(MLA)
DeepSeek的多頭潛在注意力(MLA)采用了一種新穎的方法來減少KV緩存開銷。雖然MQA和GQA通過頭共享來實(shí)現(xiàn)這一目標(biāo),MLA則采用低秩潛在壓縮技術(shù),在保持多頭注意力的優(yōu)點(diǎn)的同時(shí),減少了KV緩存的大小。
- MLA通過將鍵(keys)和值(values)壓縮成低維度的潛在向量,來減少KV緩存的大小。
- 它將鍵值嵌入(key-value embeddings)降投到一個(gè)壓縮的潛在空間:
10) SnapKV
11) 只緩存一次(YOCO)
YOCO修改了Transformer架構(gòu)以優(yōu)化緩存:
- 全局緩存:使用解碼器-解碼器設(shè)計(jì),只有一個(gè)共享的KV緩存。
- 復(fù)雜度減少:將內(nèi)存從O(N x L)減少到O(N + L),其中N是序列長度,L是層數(shù)。
- 高效注意力:自解碼器采用滑動窗口注意力或門控保留機(jī)制,使內(nèi)存使用保持恒定(O(C),其中C是小窗口大小)。
結(jié)論
KV-Cache技術(shù)是將Transformer模型擴(kuò)展和優(yōu)化到實(shí)際應(yīng)用中的核心。像動態(tài)逐出、壓縮和結(jié)構(gòu)化近似等創(chuàng)新,持續(xù)推動著在長上下文或資源受限的場景中實(shí)現(xiàn)更高效的技術(shù)。KV-Cache仍然是一個(gè)活躍的研究領(lǐng)域,既提供了理論上的見解,也帶來了實(shí)際的改進(jìn)。
公眾號大模型自然語言處理 作者:余俊暉
原文鏈接:??https://mp.weixin.qq.com/s/7j9sVIlJQDPnji9bas09ig???
