美團 Flash Communication:LLM 推理的 AllReduce 通信優化 精華
一、背景
前段時間的文章里我們剛剛介紹過兩個對 LLM 分布式推理場景中 AllReduce 的優化工作,一個是 NVIDIA TensorRT-LLM 中的 MultiShot 無損優化,另一個是 Recogni 提出的基于量化壓縮實現的 AllReduce 加速方案。本文中我們繼續介紹美團新發表的 AllReduce 量化壓縮優化方案。
對應的論文為:[2412.04964] Flash Communication: Reducing Tensor Parallelization Bottleneck for Fast Large Language Model Inference [1]
二、摘要
隨著 LLM 規模的不斷增長,快速推理所需的分布式解決方案往往要利用多維并行性,將計算負載分散至 GPU 集群的多個設備上。然而,此方法往往會引入顯著的通信開銷,尤其在帶寬受限的設備上(比如沒有 NVLink 或者跨機的情況)。
本文中作者提出 Flash Communication,一種新穎的低比特壓縮技術,旨在緩解推理過程中 Tensor Parallelism(TP)的通信瓶頸。作者在多種最新的 LLM 上進行的廣泛實驗,驗證了該方法的有效性。該方法可以將節點內通信速度提升 3 倍,并將首 Token 時間縮短 2 倍,同時幾乎不犧牲模型精度。
PS:上述的結論其實有點夸大,INT4 可以實現上述速度,但精度損失還是有點大的;而 INT8 可以保持精度,但加速比又沒這么多。
三、引言
3.1 硬件拓撲
作者論文中評估主要采用了兩種機型,一種是 8 x L40 GPU 節點,如下圖 Figure 12 所示。每個節點上有 8 個 L40 GPU,每 2 個 GPU 在一個 PCIe Switch 下,沒有 NVLink + NVSwitch,并且每個節點只有 1 個 100 Gbps 的 NIC。因此:
- 如果節點內的 TP 通信都需要走 PCIe 鏈路,如果是不同 CPU Socket 下的 GPU 通信,還需要通過 CPU 之間的 UPI,因此通信效率可能比較低。
- 如果節點間通信,則必須通過節點的 NIC,最糟糕的情況是左側紅框和右側紅框的 GPU 組成 TP 組進行通信。
- PS:本文中作者并沒有涉及節點間通信,甚至 L40 上的 TP=8 的 8 GPU 通信都沒有。?
而 A100 節點類似下圖所示,節點內有 8 個 A100 GPU,這些 GPU 通過 NVLink + NVSwitch 實現全互聯,任何兩個 GPU 之間的通信帶寬都可以達到 600 GB/s。不過作者介紹節點間通信帶寬是 200 Gbps,那么可能節點上就沒有紅框中的 NIC,只有藍框中的 200 Gbps NIC。(PS:論文中也不涉及節點間通信)
此外,A100 GPU 有 108 個 SM,而 L40 GPU 有 108 個 SM。
3.2 ReduceScatter + AllGather
我們在之前的文章中詳細介紹過 AllReduce,這里再簡單陳述一下。對于常見的基于 Ring 的 AllReduce 實現中,通常將一個 AllReduce 操作拆分為一個 ReduceScatter 和一個 AllGather 操作,如下圖所示:
具體的 ReduceScatter 操作如下,每個設備(GPU)發送一部分數據給下一個設備,同時接收上一個設備的數據并累加。這個過程執行 K-1 步,ReduceScatter 后每個設備都包含一部分數據的 Sum:
具體的 AllGather 操作如下,每個設備(GPU)將其持有的部分結果發送給下一個設備,同時接收上一個設備的部分結果,逐步匯集完整的結果,同樣需要 K-1 步。AllGather 后,每個設備都包含全量的數據:
NVIDIA 在 3x Faster AllReduce with NVSwitch and TensorRT-LLM MultiShot | NVIDIA Technical Blog [2] 中并沒有介紹 ReduceScatter 的優化,不過在我們推測其可能采用了下述的優化方式:具體來說,Ring ReduceScatter 可以等效為一個 All2All 操作實現數據的重排,然后在 Local 進行 Reduce 操作(或者 NVSwitch 上進行 Reduce 操作)。此過程只有一個 All2All 的整體通信操作,雖然實際上與 Ring 實現的方式的通信量和計算量沒有變化,但可以避免 K-1 個 Ring Step 的同步,進而可以有效降低時延。
3.3 TP 推理
如下圖 Figure 3 所示,對于 LLaMA 模型推理,其一個 Transformer Layer 需要 2 次 AllReduce 通信,不過需要 Attention 以及 FFN 都采用先列切再行切的方式。以 80 層的 LLaMA 3 70B 模型為例,一次 Forward 需要 180 次 AllReduce 通信。
3.4 延遲分析
如下圖 Figure 2 所示,作者在 4*L40 GPU 上測量了 LLaMA-3-70B 模型在不同序列長度下各個部分的開銷(Batch Size 為 8),可以看出,序列越長,AllReduce 的通信占比越大,在 4K 序列長度時 AllReduce 通信開銷為 18% 左右,在序列長度達到 32K 時,通信開銷占到 40% 左右。
在 A100 GPU 上雖然有 NVLink+NVSwitch 互聯,最大的通信開銷依然可以達到 20%(PS:不過作者這里沒有提供詳細的數據)。
四、方案
4.1 量化挑戰
為了在準確性與時延之間達成最佳平衡,作者選擇采用低比特量化技術。如下圖 Figure 4 所示,可觀察到,在大 Block 下進行逐 Token 量化會導致 C4 困惑度的性能急劇下降,非對稱量化(Asym)相對較好,不過依然下降明顯,因此細粒度量化是必要的。
然而,作者發現在此情境下應用低比特激活量化并非易事。因此,作者計算了 LLaMA-3-8B 模型在激活量化前后的層級均方誤差(MSE)來研究量化的敏感性。如下圖 Figure 5 左圖所示,下投影 dproj 的量化難度遠高于輸出投 影oproj。
此外,All-Reduce 中 Reduce-Scatter 和 All-Gather 操作對應的量化難度也各不相同,如下圖 Figure 5 右圖所示。這一現象符合預期,因為 Reduce-Scatter 前的量化僅引入舍入誤差,而在 All-Gather 中,則同時包含舍入誤差和累積誤差。作為替代方案,可以在 All-Gather 操作前采用更高精度的量化以提升準確性。
4.2 通信算法
鑒于上述問題,作者設計了一種兩步量化策略以替代傳統的 Ring AllReduce 方法,稱為 Flash AllReduce,如下圖 Figure 6 所示。該策略與 TP 的結合如上圖 Figure 3 所示。
如下圖 Figure 6 展示了本文 Flash Communication 的通信原理:
- 首先,將每個 GPU 上的激活值按 Rank 的數量進行劃分。
- 在激活值上進行細粒度量化后,執行All2All 通信(與我們猜測的 TRT-LLM 的 MultiShot 實現類似),使得每個設備接收其規約所需的計算負載。當然,接收后也需要反量化操作。
- 在設備內完成 Reduce,不涉及通信操作。
- 對得到的結果再次進行量化以加速傳輸。然后進行AllGather以匯總所有結果,并在每個設備上進行反量化以恢復浮點數值。?
具體的算法過程也可以參考如下圖 Algorithm 1:
4.3 Kernel 設計
為了提升效率,作者開發了一個融合的 Flash AllReduce Kernel,以囊括上述所有集合通信操作及量化操作。如下圖 Table 1 所示,相比 Ring AllReduce 操作,Flash AllReduce 將量化-反量化步驟從 N 次減少到 2 次,Reduce-Gather 步驟從 N-1 次縮減到 1 次。盡管總體數據個數保持不變,但每一份數據均被量化到較低位數,從而大幅減少了傳輸的數據大小。
快速細粒度量化:每個節點的總通信量(個數) M 被劃分為 T 個 Chunk 進行傳輸。給定 Chunk 大小 C,如下圖 Figure 7 展示了 GPU 線程如何并行組織以處理 Chunk 信息。一個 Chunk 被分割成 N 個 Block,每個 Block 對應 32 個 Warp,其中每個 Warp 由 32 個 Thread 組成,每個 Thread 可處理 8 個 FP16 元素。以采用 128 組大小的非對稱量化為例,使用 16 個線程對每組 128 個元素進行量化。具體而言,利用 CUDA API 函數 __shfl_xor_sync 通過迭代交換這些 Warp Thread 間的信息,高效實現 Max/Min 歸約。
快速通信。不再使用 All2All 原語,而是利用 CUDA Runtime API 中的 GPU Peer Direct Memory 訪問來傳輸量化后的數據量,在此過程中,能夠直接從不同 Rank 獲取數據,顯著提升通信速度。
快速反量化。一旦接收到量化后的數據,需要將其反量化為 FP16 以進行 Reduce Sum。由于單純的 INT4 到 FP16 轉換會產生開銷,作者采用了 [2211.10017] Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production [3] 中的反量化布局。為了在線協調其順序,作者還采用了來自 LMDeploy 的快速 INT4 打包,如下圖 Figure 8 所示:
- 給定兩個 32 位無符號整數 U0 和 U1,它們分別持有 4 個 INT4 量化的激活值(每個存儲在 8 位中的低 4 位)用于傳輸。
- 首先執行右移 12 位操作,然后對其自身進行按位或運算。
- 隨后,使用 CUDA Math API __byte_perm 從這兩個整數中選擇目標位。通過這種方式,可以方便的按順序打包 8 個 4 位整數進行反量化。
- 接下來,應用 lop3.b32 對打包變量執行邏輯操作(0xF0 & 0xCC)| 0xAA,應用掩碼 0x000F000F 和 0x64006400,然后減去 0x64006400,這有效地表示了 FP16 中的 W1 和 W0。
- 通過改變剩余 INT4 整數的掩碼,可以迭代進行反量化。?
INT6 量化:鑒于在 All-Gather 之前進行低比特量化會導致更大的損失,這里作者選擇采用 INT8 位寬,同時保持 ReduceSum 的 INT4 位寬,從而有效構建了一個 INT6 解決方案。INT6 配置在性能與通信效率之間達到了很好的平衡。
五、實驗 & 結果
5.1 實驗配置
實驗在前述的 L40 和 A100 GPU 進行,對應的輸入 Token 為 1024,輸出為 64,基線為 FP16 通信。
5.2 精度對比
5.2.1 FP16 Weight 實驗
如下圖所示,作者針對 LLaMA-2 和 LLaMA-3 系列模型使用 FP16 Weight 進行了評估,可以看出,大部分情況下 Asym INT8 的損失都很小,基本無損(紅框);Asym INT6(INT8 + IINT4)在 LLaMA-2 損失較小,在 LLaMA-3 損失稍微有點大;而 LLaMA-3 的 INT4 方案損失比較大,這也與 [2411.04330] Scaling Laws for Precision [4] 的結論相符,LLaMA-3 用了更多訓練數據,相應也更難量化):
PS:需要說明的是,上述中我們沒有使用論文表格是因為論文中出現了嚴重錯誤,上述表格中:
- AVG 列:作者論文中計算的均值,如下圖 Table 2 所示,此結果計算有誤。
- New_AVG 列:我們自己根據表格中相關數據計算的均值。
- INT8_Weight_AVG:來自下述 Table 3 中對應 INT8 Weight 推理的均值。可以看出 INT8 Weight 的均值也和我們計算的 FP16 Weight 的結果均值接近,符合預期。?
5.2.2 INT8 Weight 實驗
如下圖 Table 3 所示,作者同樣針對 LLaMA-2 和 LLaMA-3 系列模型使用 INT8 Weight 進行了評估,和上述 FP16 Weight 結論基本類似:
5.3 Flash AllReduce vs Ring AllReduce
在集成了一系列優化技術后,Flash AllReduce 的速度顯著由于 Ring AllReduce。如下圖 Figure 10 所示,作者展示了通信量在 64MB - 1GB 時的通信時延。可以看出,其 INT4 版本最高可以實現 3.18x 的 Kernel 加速,而 INT6 在速度和精度之間取得了不錯的平衡(PS:需要注意的是,實際推理過程中通信量可能沒有這么大)。
如下圖 Figure 11 所示,作者也展示了不同 SM 數量對通信效率的影響。在通信量較小時,較少的 SM 數量更為有利,因為這可以減少 Kernel 啟動和 Block 間同步的開銷。然而,隨著通信量增大,計算需求增加,也很有必要使用更多 SM。配置 48 個 SM 可以在通信與計算之間達到了更佳的平衡。
5.4 時延和吞吐
如下圖 Figure 9 所示,作者也基于 LLaMA-3-8B 和 LLaMA-3-70B 模型在 L40 和 A100 上測量了 TTFT 的時延,可以看出,在 L40 上 TP=4 最多可以獲得 2.06x 的加速(對應的 INT4,INT8 只有 1.42x);而在 A100 上 TP=8 最多可以獲得 1.19x 加速(對應的 INT4,INT8 只有 1.1x)
如下圖 Figure 13 所示,在 L40 上 TP=2 的加速會更小一些:
PS:此外,LLM Inference 在 Prefill 階段的 AllReduce 通信量比較大,而在 Decoding 階段的 AllReduce 通信量比較小,作者并沒有進行相關對比實驗。
六、參考鏈接
- https://arxiv.org/abs/2412.04964
- https://developer.nvidia.com/blog/3x-faster-allreduce-with-nvswitch-and-tensorrt-llm-multishot/
- https://arxiv.org/abs/2211.10017
- https://arxiv.org/abs/2411.04330
本文轉載自 ??AI閑談??,作者: AI閑談
