單卡A100實現百萬token推理,速度快10倍,這是微軟官方的大模型推理加速
大型語言模型 (LLM) 已進入長上下文處理時代,其支持的上下文窗口從先前的 128K 猛增到 10M token 級別。
然而,由于注意力機制的二次復雜度,模型處理輸入提示(即預填充階段)并開始產生第一個 token 可能需要幾分鐘時間。導致首個 token 生成的時間過長,從而嚴重影響了用戶體驗,這也極大地限制了長上下文 LLM 的廣泛應用。
舉例來說(如圖 2a 所示),在單臺裝有 A100 的機器上為 LLaMA-3-8B 提供服務時,如果提示有 30 萬個 token,模型需要 6 分鐘才能完成預填充( pre-filling)階段,如果提示增加到 100 萬個 token,這個數字將增加到 30 分鐘。
自注意力計算的開銷占到了總預填充延遲的 90% 以上,這使其成為 LLM 處理長上下文時的主要瓶頸。現有的加速預填充方法在應用于長上下文 LLM 時通常無法保持可接受的準確性或效率。
為了解決上述問題,來自微軟、薩里大學的研究者提出了一種旨在加速長序列處理預填充的稀疏計算方法:MInference( Milliontokens Inference )。
- 論文地址:https://arxiv.org/pdf/2407.02490
- 論文主頁:https://hqjiang.com/minference.html
- 論文標題:MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
MInference 可以直接應用于現有 LLM,無需對預訓練設置進行修改或額外的微調。
通過對各種下游任務(包括 InfiniteBench、RULER、PG-19 和 Needle In A Haystack)以及模型(包括 LLaMA-3-1M、Yi-200K、GLM-4-1M、Phi-3-128K 和 Qwen2-128K)進行評估,實驗證明 MInference 可有效將 A100 上的預填充推理延遲降低多達 10 倍,同時保持準確性。
使用 MInference 1.0 ,長上下文 LLM(如 LLaMA-3-8B-1M、GLM-4-1M)在單個 A100 上的推理速度實現了 10 倍提升,并且準確度更高。
方法介紹
作者提出了 MInference,這個名字反映了他們希望在一臺 A100 機器上實現百萬(million)token 推理的雄心。
MInference 是一種無需訓練的高效方法,用于基于動態稀疏注意力的長上下文 LLM 的預填充階段。
研究者認為注意力,特別是在長上下文中,是稀疏和動態的,即在不同的輸入中,稀疏模式有很大的不同。這種動態稀疏性呈現出三種適用于所有輸入的獨特空間聚合模式:A 形(A-shape)、垂直 - 斜線(Vertical-Slash)和塊狀 - 稀疏(Block-Sparse)。
MInference 首先使用內核感知稀疏模式搜索算法為每個頭部離線確定最佳動態稀疏模式,如算法 1 所示。在推理過程中,它會根據頭部的模式動態逼近動態稀疏指數,如算法 2、3 所示。最后,作者使用優化后的 GPU 內核執行高效的動態稀疏注意力計算,大大減少了長上下文 LLM 的預填充階段延遲。
例如,對于「垂直 - 斜線」模式,作者首先利用最后一個 Q 和 K 之間的注意力計算來估計垂直線和斜線的最佳指數。然后,他們利用動態稀疏編譯器 PIT 和 Triton 構建垂直 - 斜線 FlashAttention 內核,加速注意力計算。對于 A 形、垂直 - 斜線和塊狀 - 稀疏模式,作者首先在注意力計算中使用 Q 和 K 的均值池。利用均值池和 MatMul 的交換屬性,可以估算出塊狀 - 稀疏指數。然后,他們使用 Triton 構建塊稀疏 FlashAttention 內核,加速注意力計算。有關內核的詳細實現,請參閱附錄 C.4 和代碼。
在長上下文基準中的評估結果
作者在一系列場景中測試了 MInference,包括 QA、編碼、基于檢索的任務、multi-hop QA、總結和數學任務。RULER 基準包括幾個復雜的 multi-hop 或 multi-needle 任務,有效地反映了 LLM 的實際上下文窗口大小。如表 1 所示,MInference 有效地保留了 LLM 的實際上下文窗口處理能力,甚至將實際上下文窗口大小略微擴展到 32K。
作者還使用平均 token 長度為 214K 的 InfiniteBench 在更廣泛的任務中測試了 MInference,如表 2 所示。與 SoTA 基線相比,MInference 在所有任務中都始終保持了良好的性能。值得注意的是,在更具挑戰性的檢索任務(如 KV 檢索任務)中,所有基線都無法做出準確預測,準確率低于 1.2%。但是,MInference 成功地保留了處理動態 KV 對檢索的能力。
為了進一步評估不同上下文長度和關鍵信息在提示中不同位置時的性能,作者使用「大海撈針」任務測試了各種模型和方法。如圖 1 所示,MInference 在不同的模型、上下文窗口和提示信息位置下都表現良好,與原始模型相比,其性能保持不變甚至略有提高。在 LLaMA-3-8B 和 GLM-4-9B-1M 的情況下,MInference 在高達 1M 的上下文窗口中實現了完全綠色的性能。相比之下,即使在 70K 上下文窗口中,StreamingLLM 和 InfLLM 在提示的中間段性能也會下降到 20% 以下。
作者還使用 PG-19 在語言模型任務中測試了 MInference,其中包括多達 100k 的 token。如圖 2 所示,MInference 有效地保持了 LLaMA-3-8B 和 Yi-9B-200K 的困惑度,而所有基線都出現了不同程度的困惑度下降。此外,與標準的 StreamingLLM 相比,使用膨脹和步長配置的 StreamingLLM 更好地保持了困惑度性能。
延遲和內核中的稀疏模式
圖 3 展示了本文提出的三種注意力模式以及 FlashAttention 的微基準測試結果。可以看出,Vertical-Slash 是三種模式中最慢的,但在 1M 上下文窗口下,相比 FlashAttention 仍然實現了 13 倍的加速。
圖 4 展示了 Vertical-Slash 頭部內核中的稀疏索引。垂直線通過 PIT FlashAttention 使用 1x64 塊計算,而斜線通過塊級 FlashAttention 使用 64x64 塊計算。