DeepSeek的MLA,任意大模型都能輕松遷移了
復旦 NLP 實驗室博士后紀燾是這篇文章的第一作者,研究方向為大模型高效推理、多模態大模型,近期代表工作為首個NoPE外推HeadScale、注意力分塊外推LongHeads、多視覺專家大模型MouSi,發表ACL、ICLR、EMNLP等頂會頂刊論文 20 余篇。
DeepSeek-R1 作為 AI 產業顛覆式創新的代表轟動了業界,特別是其訓練與推理成本僅為同等性能大模型的數十分之一。多頭潛在注意力網絡(Multi-head Latent Attention, MLA)是其經濟推理架構的核心之一,通過對鍵值緩存進行低秩壓縮,顯著降低推理成本 [1]。
然而,現有主流大模型仍然基于標準注意力架構及其變種(e.g., MHA, GQA, MQA),推理成本相比 MLA 呈現顯著劣勢。使預訓練的任意 LLMs 快速遷移至 MLA 架構而無需從頭預訓練,這既有重大意義又具有挑戰性。
復旦 NLP 實驗室、華東師大、上海 AI Lab、海康威視聯合提出 MHA2MLA 框架,通過部分 RoPE 保留(Partial-RoPE)和鍵值聯合表示低秩近似(Low-rank Approximation)兩個關鍵步驟,成功將任意 MHA/GQA 架構遷移到 MLA。
目前,MHA2MLA 已位列??alphaXiv 熱度榜??
復旦 NLP 實驗室博士后紀燾為第一作者,副研究員桂韜為通訊作者。
- 論文標題:Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs
- 論文鏈接:https://arxiv.org/abs/2502.14837
- 開源代碼:https://github.com/JT-Ushio/MHA2MLA
論文概覽
本文聚焦如何將預訓練的基于 MHA/GQA 的大語言模型高效遷移到 DeepSeek 提出的經濟推理架構 —— 多頭潛在注意力(MLA)。
MHA 與 MLA 在多處存在差異,使得 MHA2MLA 極具挑戰:
- 位置編碼不同:MHA 采用全維度位置編碼(PE),MLA 僅少量維度采用 PE,剩余維度則 PE 無關
- 緩存對象不同:MHA 緩存分離的鍵向量及值向量,MLA 緩存帶 PE 的鍵向量及 PE 無關的鍵值聯合低維表示向量
- 參數矩陣不同:MHA 包含查詢、鍵、值三個線性變換矩陣,MLA 則更加復雜、多達七個目的不同的線性變換矩陣
- 運算形式不同:MHA 的運算受限于訪存瓶頸,MLA 則能通過矩陣吸收等優化實現更高的訪存效率
本文提出的 MHA2MLA 為了最大化利用 MHA 預訓練參數矩陣并對齊 MLA 的緩存對象和運算形式,首先通過部分 RoPE 保留(Partial-RoPE)分離出 PE 相關表示(少量維度,如 1/8)和 PE 無關表示(大量維度),其中 PE 相關的鍵向量對齊 MLA。其次拼接值的變換矩陣(W_v)和 PE 無關的鍵的變換矩陣(W_{k, nope}),并進行 SVD 分解得到降維變換矩陣和升維變化矩陣,中間的鍵值聯合低秩表示對齊 MLA,完成了緩存對象的對齊以及運算形式的對齊。
在 135M~7B 上的實驗表明,僅需使用預訓練數據的 0.3% 到 0.6% 進行高效微調,即可基本還原架構遷移帶來的性能損失。并且 MHA2MLA 還能結合其他高效推理技術,例如結合 4-bit KV 緩存量化,Llama2-7B 減少了 92.19% KV 緩存,而 LongBench 上的性能僅下降 0.5%。
部分 RoPE 保留(Partial-RoPE)
為了實現從標準的 MHA(多頭注意力機制)到 MLA(多頭潛在注意力機制)的遷移,作者提出了部分 RoPE 微調(partial-RoPE finetuning)策略,該策略通過從大量維度中移除 RoPE(旋轉位置編碼)并將其轉換為 NoPE(無位置編碼)來解決 MLA 和 RoPE 沖突的問題。
作者主要嘗試了四種移除 RoPE 的策略:1)保留高頻位置信息 S_high,該方法最簡單直接,保留了局部語義特征相關的高頻特征 [2];2)保留低頻位置信息 S_low,與保留高頻位置信息的策略形成對比,檢驗低頻成分在語義理解任務中的潛在作用;3)均勻采樣策略 S_uniform,等間隔均勻采樣頻率保留位置頻率;4)使用查詢、鍵向量范數乘積 (2-norm) 近似注意力貢獻值 [2] 的篩選策略 S_{2-norm},針對每個注意力頭,計算所有頻率的平均 2-norm 分數,隨后選擇得分較高的頻率保留位置信息。該策略能自適應識別對模型性能關鍵的特征頻率。
Partial-RoPE 的消融實驗表明:1)保留低頻位置信息的 S_low 導致了最大的性能損失,保留高頻位置信息的 S_high 導致的性能損失明顯小于保留低頻,說明了高頻維度的重要性;2)S_uniform 和 S_{2-norm} 均展現出更優的性能,分別在 135M 模型和 1.7B 模型上取得了最少的性能損失。最終作者選擇 S_{2-norm} 作為默認配置,是因為注意力貢獻分數較低的維度在結合低秩近似時損失更少。
鍵值聯合表示低秩近似
移除了大量維度的 RoPE 之后,MHA2MLA 就可以對值向量和 PE 無關的鍵向量進行低秩近似,從而大幅減少緩存空間。為最大化保留預訓練知識,本文提出兩種基于奇異值分解 (SVD) 的投影矩陣初始化策略:1)SVD_split,分別對矩陣進行低秩分解,保持各自的表征特性;2)SVD_joint,考慮鍵值矩陣之間的關聯性,參數矩陣拼接后整體進行低秩分解。
消融實驗表明:無論是在 GQA 基座還是 MHA 基座上,SVD_joint 方法始終優于 SVD_split 方法。
實驗結果
作者在多種規模的語言模型(SmolLM-135M/360M/1B7 和 Llama2-7B)以及不同壓縮比例的配置下評估了所提出的方法。實驗表明:1)相同微調設置下,壓縮比例越高,性能損失越大,特別是對于兩個 GQA 模型;2)相同壓縮比例下,原始模型參數越多,性能損失越小,揭示了 MHA2MLA 的潛在 scaling law。3)MHA2MLA 的微調數據量僅需預訓練數據的 0.3%~0.6%,避免了從頭預訓練 MLA 模型的高昂成本。
作者在 LongBench 長文本生成任務中評估了結構遷移后的 Llama2-7B 模型,將 KV 緩存量化作為基準對比方案。實驗表明,MHA2MLA 能在 d_{kv}=16 的情況下實現與 2-bit 量化相同的壓縮比例(87.5%),同時僅損失一半的性能(-3.0% vs. -6.2%);進一步結合 4-bit 量化后,不僅壓縮比例超過 2-bit 量化,性能損失也都優于所有 2-bit 的基線方法,例如 92.19% 壓縮比例僅掉 0.5%,96.87% 壓縮比例僅掉 3.2%,證明了 MHA2MLA 能顯著減少推理時的訪存瓶頸。
總結與展望
本文主要研究如何將基于 MHA 的預訓練 LLMs(或其變體)適配為 KV 緩存高效的 MLA 架構,以顯著降低推理時的訪存瓶頸。通過精心的架構設計,MHA2MLA 僅需 0.3% 至 0.6% 預訓練數據。該框架展現了與現有壓縮技術的強兼容性,同時保持了常識推理和長上下文處理能力,為部署資源高效的 LLMs 提供了一條實用路徑。
作者提到該研究受限于硬件條件,當前實驗未能覆蓋 Llama3 等需 128K 長上下文微調的模型,也未突破 7B 參數規模的驗證瓶頸。擴展至更多的基座將作為未來工作之一。作者還計劃結合參數高效微調策略,進一步降低架構遷移過程中的參數更新規模。