ICML 2025 | 全局池化+局部保留,CCA-Attention為LLM長文本建模帶來突破性進展
琶洲實驗室、華南理工大學聯合推出關鍵上下文感知注意力機制(CCA-Attention),實現超長文本的高效上下文建模。在 128K 超長序列上下文建模任務中,CCA-Attention 的推理速度是標準自注意力機制的 7.9 倍,同時鍵值緩存(KV Cache)顯存占用減少 93%,性能全面優于現有高效注意力方法。
- 論文標題:Core Context Aware Transformers for Long Context Language Modeling
- 論文鏈接:https://arxiv.org/pdf/2412.12465
- 代碼鏈接:https://github.com/chenyaofo/CCA-Attention
- 發布時間:2024年12月17日
該成果已被 ICML 2025 接收,最早于 2024 年 12 月 17 日提交至 ArXiv,早于 DeepSeek NSA 和 Kimi MoBA 公開。CCA-Attention 不僅速度快、資源占用低,更在上下文建模的精準度和效率上樹立了新標桿,為長文本處理注入全新動力。
引言
近期研究 [1, 2, 3] 發現,LLMs 中的大多數層的注意力權重主要集中在少數 token 上,表現出顯著的稀疏性(見圖 1)。這一發現啟示我們可以借助這種稀疏特性,降低注意力機制的計算復雜度。
圖 1: LLaMA2-7B 模型中注意力權重的可視化,陰影越深表示注意力權重越高。最后一個 token 僅對上下文少數幾個 token 有著較高的注意力權重,即注意力權重具有顯著的稀疏性。
現有稀疏注意力方法 [5, 6, 7] 通常通過預定義的稀疏模式來降低計算成本。然而,在問答任務中,關鍵信息可能分布在上下文的不同位置,模型需要能夠訪問任意位置的信息,作者稱這一特性為「可達性」。已有方法往往忽視了保持 token 之間可達性的重要性,可能導致信息傳遞受限,從而影響模型在長序列和復雜任務中的表現。
為解決這一問題,作者提出了一種即插即用的高效長文本上下文建模方法——關鍵上下文感知注意力機制(CCA-Attention),其特點如下:
- 高效長文本建模: 通過全局池化注意力與局部保留注意力的協同設計,在顯著降低計算量的同時保持對長距離依賴的建模能力。
- 線性計算復雜度: 通過引入 core token 聚焦關鍵上下文,大幅提高計算效率。
- 可即插即用集成:無需修改模型結構和從頭訓練,可以輕松集成到預訓練的 LLM 中,僅需少量微調即可實現性能優化。
對比 DeepSeek 發布的 NSA [8] 需引入額外的壓縮模塊并從頭訓練 LLMs,CCA-Attention 無需引入額外參數和修改模型結構,可以無縫替換現有 LLMs 中的標準自注意力模塊。對比月之暗面發布的 MoBA [9] 通過門控機制丟棄不相關塊,CCA-Attention 通過動態聚合關鍵上下文為核心 token 的方式,在降低計算量的同時,確保所有 token 的信息交互,保留了完整的全局建模能力。
CCA-Attention:革新性的解決方案
圖 2: CCA-Attention 示意圖
全局感知池化:降低計算維度的智慧之舉
標準自注意力計算量隨序列長度呈平方級增長,長序列處理計算開銷極大。大量研究發現注意力權重的分布并不均勻,絕大部分注意力權重被分配給了少數重要 token,其余部分貢獻有限,屬于冗余上下文。
受此啟發,作者提出全局感知池化模塊。具體而言,將輸入序列,分成互不重疊的
個組,g 為分組大小。對于第 i 組
,使用該組最后一個 token
的 query 向量與組內所有 token 的 key 向量計算重要性分數,并獲得該組核心
:
其中,是第 i 組
的最后一個 token 對應的 query 向量,
是第 i 組的 key 矩陣,
和
是可學習的參數。將各組 core token 拼接起來得到 core token 序列
。
為減少冗余,作者使用 core token 序列代替原始 token 進行注意力計算,將維度從
降至
,從而降低了計算和存儲復雜度。通過 core token 序列計算得到的鍵值矩陣表示為:
其中 和
是可學習參數。
局部保留模塊:捕捉局部依賴的關鍵
盡管全局感知池化模塊能有效捕捉長距離依賴,但由于其壓縮特性,可能會忽略細粒度的局部上下文,而這些局部語義對于語言建模同樣至關重要。為此,作者進一步提出局部保留模塊(Locality-preserving Module),為全局模塊提供有效互補信息。
具體來說,該模塊會確保每個 token 都能至少關注前面 w 個原始 token,以此來捕捉局部上下文信息,保留連續性語義信息:
為了應對生成過程中標記數量難以維持為組大小 g 的整數倍的問題,作者將局部窗口大小設置為,確保注意力窗口與組大小對齊,避免信息遺漏;
是原始 token 序列經過線性變換后的鍵值矩陣。
局部保留模塊與全局池化模塊共享線性變換參數,不會引入額外參數開銷。在實際推理中,局部模塊提供精細語義支持,彌補全局壓縮帶來的信息損失,共同構成完整的上下文建模體系。
全局-局部模塊可微融合:打造全面可達性的橋梁
全局感知池化模塊和局部保留模塊在計算時都只涉及部分 token,導致注意力的可達性有限。為解決這個問題,作者采用全局-局部模塊可微融合策略。具體而言,該策略將兩種注意力模塊中的鍵值矩陣進行組合,形成統一的鍵矩陣和值矩陣
。由此,CCA-Attention 的最終輸出表示為:
其中,每個位置的輸出計算表達式如下:
基于 Triton 的底層加速:提升效率的強大動力
為了在訓練、預填充、解碼期間實現 FlashAttention 級別的加速,作者基于 Triton 實現了硬件對齊的 CCA-Attention 內核。作者借鑒 FlashAttention 的設計思路,利用 Triton 進行底層算子融合,將全局池化注意力和局部保留注意力整合為一個獨立且緩存友好的算子,有效消除冗余計算,并原生支持 KV 緩存技術,進一步提升訓練、預填充、解碼階段的計算效率。相比標準自注意力機制,CCA-Attention 在計算復雜度和 KV 緩存內存占用方面具有顯著優勢,從而在整體上實現了更快的運行速度與更高的內存利用效率。
實驗結果
實驗設置
作者將 CCA-Attention 應用于 LLaMA2-7B-32K 和 LLaMA2-7B-80K 模型,并在 SlimPajama 數據集上微調 1,000 步。對比方法包括 StreamingLLM、LM-Infinite 和 MInference 等高效注意力方法。評估指標涵蓋 LongBench 基準測試和多文檔問答準確匹配得分(EM Score)等,全面衡量模型在長文本任務中的性能表現。
長序列語言建模
在 LongBench-E 基準測試中,CCA-LLM 取得了最高的平均得分。以 LLaMA2-7B-32K 模型為例,其得分顯著優于 LM-Infinite 和 MInference;在 LLaMA2-7B-80K 模型上,CCA-Attention 依然表現出色,平均分數與標準自注意力相當,同時推理延遲和顯存占用大幅降低,展現出更強的長序列處理效率優勢。
表 1: 長序列語言建模實驗
長文檔問答任務
在多文檔問答任務的 EM Score 評估中,CCA-LLM 在不同序列長度下均展現出優異的表現,且其性能優勢隨著上下文長度的增加而愈加明顯。在處理超長上下文(如 64K 和 128K)任務時,CCA-LLM 的 EM 得分超越了標準自注意力機制,同時推理速度也顯著提升——在 128K 上下文長度下,推理速度達到標準自注意力方法的 7.9 倍,展現出其在高效長文本建模方面的突出優勢。
表 2: 長文檔問答實驗
計算和存儲效率對比
相比標準自注意力及其他高效注意力方法(如 MInference),CCA-Attention 在推理速度與內存占用方面展現出顯著優勢。不同于 MInference 等僅關注預填充(prefilling)階段加速的方法,CCA-Attention 能夠同時優化預填充和解碼(decoding)兩個階段,實現端到端的全流程高效推理。
在 64K 上下文長度下,CCA-Attention 的推理速度達到標準自注意力的 5.7 倍,KV Cache 顯存占用也大幅降低;在 128K 上下文任務中,推理速度提升更是達到 7.9 倍,同時 KV Cache 顯存使用減少高達 93%,充分體現了其在長序列建模中的高效性與實用性。
圖 3: 內存與計算效率對比
總結
作者提出了一種面向長序列建模的關鍵上下文感知注意力機制(CCA-Attention)。相比標準自注意力,在保持模型性能的前提下,CCA-Attention 顯著降低了計算開銷。
該方法由兩個互補模塊構成:
- 全局感知池化模塊:基于輸入 token 的重要性提取核心 token(core token),用于后續注意力計算,從而高效捕捉全局粗粒度的信息;
- 局部保留模塊:聚焦于鄰近 token 的細粒度上下文信息,作為對全局池化模塊的有效補充。
實驗結果表明,CCA-Attention 在多種長文本任務中表現出色,同時顯著提升了計算效率,具備良好的實用性與可集成性。