MiniCache 和 PyramidInfer 等 6 種優化 LLM KV Cache 的最新工作
一、背景
在 LLM 推理中,常常會采用 KV Cache 來緩存之前 Token 的中間結果,以顯著減少重復計算,從而降低自回歸生成中的延遲。然而,KV Cache 的大小與序列長度成正比,在處理長序列時會面臨極大的挑戰。尤其當前許多模型開始支持幾百 K 甚至幾 M 的序列長度,進一步凸顯了 KV Cache 的問題,因此很多研究工作致力于降低 KV Cache 的占用。
本文中簡單介紹幾個最新的工作,包括 SnapKV、YOCO、CLA、Layer-Condensed KV Cache、MiniCache 以及 PyramidInfer,它們都試圖降低緩解 KV Cache 的壓力。關于 GQA、MQA、DeepSeek MLA 以及量化相關的工作我們已經在之前進行了介紹,這里不再贅述。
二、KV Cache 大小
KV Cache 的大小與模型配置(層數,hidden_size,Attention head 個數等)以及序列長度、Batch Size 成正比。其中單個 Token 對應的 KV Cache 大小與模型配置相關,并且是固定的,這里將其稱為單位 KV Cache 計算公式為:
sum_token = (hidden_size / num_attention_heads * num_key_value_heads) * num_hidden_layers * 2(k, v)
而總的 KV Cache 大小為:
sum = sum_token * seq_len * batch_size
batch_size 和 seq_len 越大,KV Cache 越大,如下圖所示為 LLaMA2-7B 模型的 batch_size 和 seq_len 對應的 KV Cache 大小(默認 FP16 精度):
- 當 batch_size * seq_len 為32K時,比如 batch_size 為 1,seq_len 為 32K,其 KV Cache 大小為16GB,甚至超過模型權重大小 14GB。
- 當 batch_size * seq_len 為128K時,比如 batch_size 為 1,seq_len 為 128K,其 KV Cache 大小為 64GB,加上模型權重 14GB 甚至快要超過 A100 GPU 的 80GB 顯存限制。?
三、SnapKV
[2404.14469] SnapKV: LLM Knows What You are Looking for Before Generation 的核心思路比較簡單,如下圖 Figure 1 所示,在 Prefill 階段不是保留所有輸入 Token 的 KV Cache,而是采用稀疏化的方式,針對每個 Attention Head 將 Prompt 分為 Prefix 和 Window 兩部分;然后,通過 Window 中 Token 與 Prefix 中 Token 的 Attention Score 來選擇稀疏化的 Token;最后,將它們的 KV Cache 和 Window 中 Token 的 KV Cache 一起作為 Prompt 的 KV Cache。需要說明的是:每個 Attention Head 中從 Prefix 里挑選的 Token 可能不同。此外,Decoding 階段也不會再更新 Prompt 的 KV Cache。
SnapKV 在處理 16K Token 的輸入時,可以獲得 3.6x 的加速,內存效率提升 8.2x。同時在 16 個長序列數據集上保持了與基線模型相當的精度。此外,使用 Huggingface 可以在單個 A100-80GB GPU 上處理 380K 上下文 Token 的任務。
四、YOCO
在 [2405.05254] You Only Cache Once: Decoder-Decoder Architectures for Language Models 中,作者只保留一層全局的 KV Cache。這種設計可以大大降低 GPU 顯存的需求,加快 Prefill 階段。如下圖所示,YOCO 模型與常規 Decoder-Only LLM 的區別有幾點:
- 前 L/2 層(Self-Decoder)使用Efficient Self-Attention,實際上就是滑動窗口 Self-Attention或作者之前論文提出的Multi-Scale Retention。其只用保存窗口內的 KV Cache 即可。
- 第 L/2 層的 KV Cache 作為Global KV Cache。也就是只有一層有全局 KV Cache。
- 后 L/2 層(Cross-Decoder)使用Global Cross Attention,對應的 KV 為上一步的 Global KV Cache,也就是后續所有 L/2 層的 Cross Attention 的 KV Cache 都是相同的。?
五、CLA
[2405.12981] Reducing Transformer Key-Value Cache Size with Cross-Layer Attention 中作者同樣采用 Cross-Attention 機制來降低 KV Cache。不同的是作者并非采用固定層作為 Cross-Attention 的輸入,而是采用相鄰層,如下圖左圖所示。最簡單的方式就是隔層共享,稱作 CLA2,實際也可以每 3 層共享,稱作 CLA3,如下圖右圖所示。此外,這種方法與 MQA 和 GQA 等修改 Attention Head 的方案是兼容的。CLA2 顯存減小 2x,CLA3 顯存減小 3x。
作者訓練 1B 和 3B 參數模型模型實驗表明,CLA 相比傳統的 MQA 在顯存占用、準確性方面可以實現帕累托改進,從而實現更長的序列長度和更大的 Batch Size。(PS:但并不意味著可以優于現在廣泛采用的 GQA?)
六、Layer-Condensed KV Cache
在 [2405.10637] Layer-Condensed KV Cache for Efficient Inference of Large Language Models 中,作者同樣采用了僅計算和緩存少量層 KV Cache 的方案,從而顯著節約顯存消耗并提升吞吐量。如下圖 Figure 1 所示,僅保留最后一個 Transfomer Block 層的 KV Cache,當生成后續 Token 時其對應的 KV Cache 都從最后一層取。
七、MiniCache
在 [2405.14366] MiniCache: KV Cache Compression in Depth Dimension for Large Language Models 中,作者觀察到 KV Cache 在 LLM 中的深層部分的相鄰層之間表現出了高度相似性,可以基于這些相似性對 KV Cache 進行壓縮。此外,作者還引入了 Token 保留策略,對高度不同的 KV Cache 不進行合并。并且這種方法可以與其他的 KV Cache 量化方案正交使用。
作者在 LLaMA-2、LLaMA-3、Phi-3、Mistral 和 Mixtral 等模型上進行實驗,在 ShareGPT 數據集上,采用 4 Bit MiniCache LLaMA–7B 與 FP16 全量 KV Cache 相比實現了 5.02x 的壓縮比,推理吞吐提高約 5 倍,顯存占用減少 41%,同時性能幾乎無損。
如下圖 Figure 3 所示為其壓縮策略和保留策略:
如下圖 Figure A 所示為其詳細的執行流程:
- 1. 獲取 KV Cache:在 Prefill 階段,逐層生成 KV Cache。
- 2. 跨層合并:當到達合并開始層 S 時,將當前層 L 的 KV Cache 與前一層 L-1 的 KV Cache 進行合并,以減少冗余。
- 3. 緩存:將合并后的 KV Cache 存儲起來,以便將來使用。
- 4. 刪除:在 Decoding 階段,刪除不必要的或冗余的 KV Cache,以優化內存使用。
- 5. 加載和生成:獲取所需的 KV Cache,用于生成輸出。
- 6. 恢復:對獲取的 KV Cache 應用誤差抑制機制,包括 rescaling 和 retention recovery,以最小化合并和壓縮過程中引入的誤差。
- 7. 更新:在恢復階段后,使用最終的 KV Cache 更新共享的 KV Cache。
八、PyramidInfer
在 [2405.12532] PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference 中,作者發現影響未來生成的關鍵 KV 的數量逐層減少,并且可以通過注意力權重的一致性來提取這些關鍵 KV。基于這些發現,作者提出了 PyramidInfer,通過逐層保留關鍵上下文來壓縮 KV Cache。PyramidInfer 在不犧牲性能的情況下計算更少的 KV,并節約大量顯存。實驗結果表明,與 Accelerate 相比,PyramidInfer 的吞吐提高了 2.2 倍,KV Cache 的顯存占用減少了 54% 以上。
如下圖 Figure 2 所示為 PyramidInfer 與 StreamingLLM 和 H2O 的區別,PyramidInfer 中 KV Cache 會逐層遞減,越往后越稀疏(PS:如果是這樣,那么 Layer-Condensed KV Cache 中只保留最后一層的方案是不是不太合理):
PyramidInfer 的執行過程如下圖 Figure 6 所示:
- 在 Prefill 階段,PyramidInfer 只保留每層的關鍵上下文(Pivotal Context, PvC)來壓縮 KV Cache。
- 在 Decoding 階段,PyramidInfer 根據新的最近的 Token 來更新 PvC。?
如下圖 Table 1 所示,PyramidInfer 在使用更少 KV Cache 的情況下獲得更快的推理速度:
如下圖 Figure 11 所示,作者進一步測試了 PyramidInfer 在更多 Batch Size 下的表現,其在比較小 Batch Size 時幾乎沒有加速,主要是因為減少 KV Cache 還需要一些額外的計算;而在比較大的 Batch Size 能獲得更大的加速比。而 Full Cache 當 Batch Size 大于 32 吞吐反而降低:(PS:這個降低不太符合預期,通常來說隨著 Batch Size 的增加,計算密度會更高,相應的吞吐也應該更高,而且在 32 左右還遠沒有到 Compute Bound)。
九、參考鏈接
- ??https://arxiv.org/abs/2404.14469??
- ??https://arxiv.org/abs/2405.05254??
- ??https://arxiv.org/abs/2405.12981??
- ??https://arxiv.org/abs/2405.10637??
- ??https://arxiv.org/abs/2405.14366??
- ??https://arxiv.org/abs/2405.12532??
本文轉載自 ??AI閑談??,作者: AI閑談
