MHA -> GQA:提升 LLM 推理效率
一、背景
我們在之前的文章中詳細分析過 GQA 相比 MHA 的推理優勢(省顯存、計算強度高),不過 GQA 有可能導致精度的損失,因此早期的一些不太大的 LLM 會使用 MHA。針對這個問題有兩種優化思路:
- 將 MHA 轉換為 GQA,長短序列都適用。
- 在長序列場景使用 Token 稀疏化方案或者結合投機采樣策略。?
本文中我們介紹一個將 MHA 轉換為 GQA 的工作,不過論文的實驗還偏少,效果也不是非常好;此外,最新的模型基本都在預訓練階段默認采用 GQA(LLaMA3 8B、LLaMA3.2 3B 以及 Microsoft 的 Phi 系列模型等),降低了本文工作的應用場景。
對應的論文:[2412.20677] Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA [1]
相關工作也可以參考我們以前的文章:
- ??微軟 RetrievalAttention: LLM+ANN, LLM 推理速度與精度的平衡??
- ??LLM 推理的 Attention 計算和 KV Cache 優化:PagedAttention、vAttention 等??
二、摘要
LLM 在多種自然語言處理任務中展現出卓越性能。然而,隨著模型規模與輸入序列長度的增長,KV Cache 的急劇膨脹顯著拖慢了推理速度。鑒于此,作為 MHA 的替代方案,GQA 已被廣泛引入 LLM。本研究提出了一種低成本方法,可將 MHA 模型按任意 KV Head 壓縮比修剪為 GQA 模型。
該方法基于 L0 掩碼逐步剔除冗余參數。此外,在不改變模型的前提下,對注意力頭施加正交變換,以在修剪訓練前提升 Attention Head 間的相似度,從而進一步優化模型性能。本方法兼容RoPE,意味著訓練后的模型能完全適配主流標準 GQA 框架。實驗表明,僅通過監督微調,提出的策略即可將 LLaMA2-7B 模型的 KV Head 壓縮高達 87.5%,且性能損失極小。
三、引言
如下 3.1 和 3.2 部分在我們之前的文章中有相吸介紹:???LLM 推理的 Attention 計算和 KV Cache 優化:PagedAttention、vAttention 等??。
3.1 MHA Attention 計算
如下圖所示為標準的 LLM Decoding 階段的 Multi-Head Attention(MHA)計算,其中的 D 表示 hidden size,H 表示 Head 個數,L 表示當前是在序列的第 L 個 Token。可以看出:
- 當Batch Size 為 1時,圖中紅色、綠色、藍色處的矩陣乘法全部為矩陣乘向量,是明顯的 Memory Bound,算術強度不到 1。
- 當Batch Size 大于 1時(比如 Continuous Batching):
- 紅色和藍色部分:因為是 Weight 乘以 Activation,所以不同的 Request 之間可以共享 Weight。這里變成矩陣乘矩陣,并且 Batch Size 越大,算術強度越大,也就越趨近于 Compute Bound(FFN 層也類似)。
- 綠色部分:這里 Q、K 和 V 的 Attention 計算,是 Activation 乘以 Activation,所以不同的 Request 之間沒有任何相關性。即使 Batching,這里也是Batched 矩陣乘向量,并且因為序列長度可能不同,這里不同 Request 的矩陣乘向量是不規則的。也就是說,這里算術強度始終不到 1,是明顯的 Memory Bound。
從上可以看出,通過 Continuous Batching 可以很好的將 Memory Bound 問題轉變為 Compute Bound,但 Q、K 和 V 的 Attention 計算的算術強度卻始終小于 1。根據 Amdahl 法則,如果系統中有一部分無法優化,即使把其他部分優化到可以忽略,不可優化的部分也會決定整個系統的性能上限。不幸的是,Sequence Length 越長,這里的計算量就越不可忽略。
根據模型配置信息可以估算出模型中 Q、K 和 V 的 Attention 計算與其他矩陣計算的比例大約為 (L+D)/(12*D)(PS:準確值需要根據具體的模型參數計算)。也就是說,當序列長度 L 等于 12 倍的 hidden size 時,兩部分的計算量相當,即使其他矩陣計算優化到 0,加速比也只有 2x。比如 LLaMA 2 7B 的 hidden size 為 4K,當序列長度達到 44K 時,兩部分的計算量相當,要優化的重點也會很不一樣,這也是很多長序列相關工作會在 Attention 部分采用稀疏 Attention 的一個重要原因。
3.2 GQA Attention 計算
早期通常只有比較大的模型才會采用 GQA([2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints),比如 LLaMA -2 70B,而 LLaMA-2 7B/13B 都沒有采用 GQA。然而,LLaMA-3 8B 中也用上了 GQA,甚至其他更小的模型也在將 MHA 替換為 GQA。
- 使用 GQA 有個非常大的好處:在推理階段可以顯著降低 KV Cache 的大小,比如,相比 32 個 KV Head 的 MHA,32 個 Query Head,8 個 KV Head 的 GQA 的 KV Cache 大小可以降低到 MHA 的 8/32=1/4,這也為更大的 Batch Size 提供了空間,可以進一步提升吞吐。
- 除此之外,還有一個比較大的好處:可以明顯提升 Q、K 和 V 的 Attention 計算的算術強度。此時雖然不同的 Request 之間同樣不能共享,但是同一個 Request 中的不同 Head 可以共享,比如 4 個 Query Head 共享 1 個 KV Head,則算術強度就會接近于 4,也可以更充分發揮 Tensor Core 的算力。
使用 MHA 時,Q、K 和 V 的 Attention 計算可以使用 CUDA Core 也可以使用 Tensor Core。由于 Tensor Core 要求矩陣的 Shape 是 8 的整數倍,如果不滿足就只能 Padding:
- 對于MHA而言,其是矩陣乘向量,則有7/8 的計算是冗余的。
- 對于GQA而言,如果 4 個 Query Head 共享 1 個 KV Head,則 Attention 計算有 4/8 的計算是冗余的,如果8 個 Query Head 共享 1 個 KV Head,則沒有計算的冗余。很多框架已經做了相關優化,比如 LMDeploy,TRT-LLM 的 XQA 等。
- 此外,PagedAttention 的 KV Cache 是非連續存儲的,導致即使使用 GQA 也無法利用 Tensor Core。
PS:對于 GQA 而言,理論上也可以期望 GPU 的 L2 Cache 能夠緩存到共享的 Key 和 Value Cache,從而緩解 IO Bound 問題,然而實際上無法人為控制,不一定能達到理想的效果。
3.3 動機
作者從 C4 訓練集采樣了 128 個 Sequence,共 128*2048=262144 個 Token,評估了 LLaMA2-7B 模型中每個 Transformer Block 中 Attention Head 的 KV Cache 的相似性。
如下圖 Figure 2 所示,分析發現,大多數 Head 之間的 KV Cache 幾乎是正交的,僅有少數 Head 共享較高的相似度。這表明直接對投影矩陣進行均值化會導致性能顯著下降,說明 Attention Head 之間存在重要的獨特性。
根據之前 [2406.07056] Effectively Compress KV Heads for LLM [2] 的研究,KV Cache 的低秩性為優化提供了新思路:
- 可通過正交變換對齊 Key 和 Value 的投影矩陣。
- 這種方法降低了優化的難度,并為 MHA 轉換為 GQA 提供了理論支持。
四、方案
4.1 網絡轉換
主要目的是:在剪枝訓練之前,對模型進行轉換,以增加同一組內不同 Attention Head 之間的相似性,從而提高模型優化的效率。具體的過程大概為:
- 根據前述的方案,使用部分 C4 的訓練集來收集相應的 KV Cache。
- 基于余弦相似性或者歐氏距離,計算最優的正交矩陣。
- 將計算得到的正交矩陣融合到對應的 Q、K、V 投影矩陣中,保證計算不變性。對于 Q 和 K 的投影矩陣,要考慮 RoPE 的場景,在子空間應用正交變換。
通過正交變換,可以使得同一組內不同 Attention Head 在特征空間中更加接近,從而在后續的剪枝訓練過程中更容易找到合適的參數共享方式,提高模型的壓縮效果和性能。
如下圖 Figure 3 所示,作者展示了不同的 Block 中轉換前和轉換后的 KV Cache 相似性,可以看出,轉換后相似性明顯增加:
4.2 找到更好的分組方法
在獲取了每對 Attention Head 之間的相似度評分后,可依據這些評分對 Attention Head 進行重新分組。將一個組的相似度評分定義為該組內每對 Attention Head 之間相似度評分的總和,而每種分組結果的總相似度評分則是所有組相似度評分的累加。
合理的分組方式可以使得同一組內的 Attention Head 在特征空間中更加相似,從而在剪枝時更容易找到合適的參數共享方式,提高模型的壓縮效果和性能。
4.3 剪枝訓練
主要目的是:通過剪枝訓練,逐步將原始的 KV Head 轉移到新的 KV Head 上,同時保持模型性能。如下圖 Figure 1 所示,具體過程包括:
- 添加新的投影矩陣:在每組內使用 Mean Pooling 初始化新的投影矩陣。
- 應用 L0 掩碼:引入 L0 掩碼來控制原始 KV Head 和新 KV Head 之間的轉換。初始時,掩碼值為 1,表示使用原始 KV Head;在剪枝過程中,逐步將掩碼值約束為 0,表示使用新的 KV Head。
- 知識蒸餾:使用 KL 損失和 BiLD 損失,鼓勵學生模型與教師模型的輸出對齊,從而保持模型性能。
五、實驗評估
如下圖所示,作者在多個任務上進行評估,GQA-16(32 個 KV Head 變為 16 個) 時平均精度甚至有所提升。但是 GQA-8(壓縮 4x)和 GQA-4(壓縮 8x)時損失就比較大:
六、參考鏈接
本文轉載自 ??AI閑談??,作者: AI閑談
