字節 TileLink:編譯生成高效的計算和通信 Overlap Kernel
一、背景
筆者之前的文章(萬字綜述 LLM 訓練中的 Overlap 優化:字節 Flux 等 7 種方案)中詳細介紹過各種計算與通信 Overlap 的方案,這里進一步介紹字節最近發表的 TileLink,其中提到的大部分工作已經包含在我們之前的綜述中,建議優先閱讀,比如 CoCoNet、Centauri、Flux 等。
對應的論文:[2503.20313] TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives [1]
二、摘要
大規模深度學習模型通常需要分布式系統以實現高效的訓練與推理,分布式模型執行的基礎構建模塊是層內并行算子。提升層內并行算子性能的最有效方法在于實現計算與通信的 Overlap。這種 Overlap 可通過算子分解(Operator Decomposition)或 Kernel 融合(Fusion)兩種方式達成:
- Operator Decomposition 雖易于實現,但性能往往欠佳。
- 將通信 Kernel 與計算 Kernel 相融合則需深厚的專業知識且易出錯。
本文中,作者提出 TileLink,旨在高效編譯并生成計算-通信 Overlap 執行的 Kernel。TileLink 由前端(Frontend)和后端(Backend)構成:
- 在前端,系統通過以 Tile 為中心的原語將通信與計算的設計空間解耦并建立關聯。
- 在后端,將這些原語轉換為底層指令,整合通信與計算組件以實現 Overlap 執行。
實驗表明,TileLink 相較于非 Overlap 基線實現了 1.17x 至 20.76x 的加速,并在 GPU 上達到了與當前最優 Overlap 執行庫相當的性能水平。
三、引言
3.1 北大 Centauri
北大在 [ASPLOS 24.04] Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning [2] 中介紹了 Centauri 框架,其構建了一個由三個固有抽象維度組成的切分空間:原語替換、拓撲感知組切分及工作負載切分。這些維度共同構成了一個全面的優化空間,用于高效 Overlap。為確定通信與計算的高效 Overlap,作者將混合并行訓練中的調度任務分解為 OP、Layer 和模型三個層次。
如下圖 Figure 3 所示,Centauri 的工作流程包含兩個核心環節:通信切分與層次調度。以 DP 與 FSDP 混合并行訓練為例:
- 通信切分:通過考量三個基本維度,生成潛在切分空間,并為每種集合通信選擇高效策略。
- 層次調度:在上述全面但較大的切分空間下,優化整圖的 Overlap 調度成為一項復雜的任務,為了簡化復雜的調度任務,作者將復雜的混合并行集合通信分解為三個層次,每個集合通信被分配至特定調度層級。各層級選取開銷較低的切分與調度方案,旨在實現整體優化 Overlap 方案。
有一系列類似 Centauri 的算子分解方法,其核心是:將通信和計算 Kernel 拆分為更小規模的同構 Kernel,隨后將其分配到多個通信-計算 Kernel 對中。這些拆分后的小 Kernel 可被調度到不同的 Stream 上,使得通信 Kernel 和計算 Kernel 能同時對切分的數據分片進行操作。
然而,類似上述算子分解的方法有一些局限性:
- 分解后的 Kernel 間的同步機制需要 Host 端介入,會在運行中引入不可忽略的開銷。
- L2 Cache 利用率降低、資源量化效率不足,導致分解后的 Kernel 性能可能出現惡化。
這里的資源量化效率不足(Resource Quantization Inefficient)是指計算資源切分不均衡等導致的浪費,如下圖 Stream-K([2301.03598] Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU [3])中提到的問題:
3.2 字節 Flux
字節在 [2406.06858] FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion [4] 中提出 Flux,旨在通過依賴計算隱藏 GPU 間的通信時延。Flux 將通信和計算操作分解為更細粒度的操作,并進一步融合成更大的 Kernel,從而在不損害 Kernel 效率的前提下有效隱藏通信。在融合 Kernel 的情況下,Flux 有望重疊高達 96% 的通信時間。
如下圖 Figure 5 展示 Flux 中 ReduceScatter 里 Overlap 與其他方案的差異。現有 Overlap 方案 Tm 理論上可能比原始方法 Tc 執行得更快,但通常情況下,Tm 仍慢于原始 GEMM 操作時間 Tg。主要原因在于,將一個 GEMM Kernel 拆分為一系列較小的 GEMM Kernel 會降低 GPU GEMM 的執行效率。GEMM 通常需要合理大小的矩陣才能充分利用 GPU 的計算能力。這些具有數據依賴性的小型 GEMM 操作序列進一步阻礙了 GEMM Kernel 通過 GPU 多路復用技術并行運行,因此,Tensor 并行度越高,GPU 上的 GEMM 效率越低。
相比之下,作者提出的技術不存在上述限制。作者的 Overlap 方案 Tf 能夠在極小開銷下實現與原始 GEMM 操作 Tg 相當的性能。其細粒度分解策略完美契合現代 GPU 設計特性,即通過上下文切換的 Warp 和數百個在 SM 間并發活躍的 Warp 來隱藏延遲,如下圖 Figure 5 底部所示。最終,作者的方法在不影響 GEMM 計算效率的前提下,僅在執行末尾引入少量通信開銷。
然而,雖然這種方式實現的 Kernel 效率很高,但是開發成本同樣很高,尤其是針對不同場景、模型可能需要開發特定的 Kernel。DeepSeek 可以做深度的 DeepEP、DualPipe 等優化的一個前提就是其模型、硬件相對恒定,可以一勞永逸。
四、方案
4.1 概覽
本文工作主要聚焦于層內并行,為了說明 TileLink 的優勢,作者以 MLP 的 Tensor Parallelism(TP) 為例,如下圖 Figure 1 所示,其實現包含 AllGather + GEMM(AG+GEMM)與 GEMM+ ReduceScatter(GEMM + RS),其配置與 LLaMA-7B 一致:
如下圖 Table 2 所示,采用不同技術方案的性能進行對比,其中 Non-Overlap 為直接使用 cuBLAS 和 NCCL 的無 Overlap 方案;Decomposition 則為采用算子分解技術。可以看出,Decomposition 是性能最差的,Fusion 方案在 AG + GEMM 中最優,TileLink 在 GEMM + RS 中最優,同時 AG + GEMM 與 FLUX 性能接近(約達 99%)。同時,FLUX 需要 2000 行 CUDA 代碼,而 TileLink 僅需 200 行 Python 代碼,編程效率提升 10x。
PS:之前的 CoCoNet([ASPLOS 22] [2105.05720] Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads [5]) 和 Dist-Einsum(Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models | OpenReview [6]) 也可以生成 Overlap Kernel,但是其只能生成特定 Overlap Pattern 的算子,不夠靈活。
4.2 前端原語(Frontend Primitives)
4.2.1 解耦設計空間
設計計算+通信融合 Kernel 存在兩種方式:一種是將兩部分優化選擇緊密耦合;另一種是解耦計算 Kernel 和 通信 Kernel 設計。本文的 TileLink 選擇后者,因為其解構設計空間能為 Kernel 設計提供更多靈活性,從而可能獲得更優性能。(PS:是否也有可能喪失聯合設計的優勢,比如更加均衡的資源分配?)
解耦設計空間分為 3 個子空間:
- 分塊尺寸(Tile Size):如下圖 Figure 2a 所示,通信組件每次傳輸 128x128 的 Tile;計算組件每次處理 128x256 的 Tile。Tile 的大小與使用的處理核心數相關,比如通信組件占用更多核心時使用較小的 Tile 可以更充分的利用全部核心資源;反之,核心數較少時更大的 Tile 更加高效。
- 分塊順序(Tile Order) :如下圖 Figure 2b 所示,通信組件可以與計算組件采用不同的分塊順序。分塊順序的選擇也存在權衡:若計算組件等待多個 Rank 的數據分塊,則能在處理更大數據塊時獲得更好的 Cache 效率,但可能等待時間變長;反之,僅等待單個 Rank 的數據分塊,可提早開始計算,但整體計算效率可能降低。圖中例子為通信采用 Ring 順序而每次迭代等待 2 個 Rank 數據。
- 資源映射(Resource Mapping):如下圖 Figure 2c 所示,通信與計算組件可映射到不同單元或相同單元。比如,如果通信組件使用 Copy Engine(DMA),可以避免與計算組件的資源沖突,但需要承擔 Host 帶來的額外開銷;但是如果采用計算核心執行數據拷貝,則可以消除 Host 開銷,但可能引發資源沖突,這適用于計算組件無法充分利用所有處理核心的場景。
4.2.2 Tile 為中心的基礎原語
解耦通信與計算的設計空間也會引入同步的挑戰。由于這兩個組件采用不同的分塊尺寸、分塊順序和資源映射方案,實現二者同步需要進行復雜的底層編程并插入通信指令。以 GPU 為例,要求使用諸如 ld.global.acquire 和 red.release 等特殊指令。然而,這類指令的編程模型和代碼生成編譯器的工作機制存在根本性差異,現有編譯器普遍缺乏對內存一致性模型的原生支持。
- ld.global.acquire:獲取語義,確保之后的操作不會提前執行,確保讀取的變量是最新的,防止 CPU 或其他 GPU 線程的舊值污染數據。
- red.release:釋放語義,確保之前的寫入對其他線程可見,確保數據在此操作之前全部寫入,防止寫入亂序執行。
- 這兩個指令通常用于同步機制,特別是生產者-消費者、互斥鎖、信號量等場景,以保證不同 GPU 線程間的正確通信。
為解決上述問題,TileLink 提供了一套以 Tile 為中心的基礎原語。這些原語引入了內存一致性(Memory Consistency)語義,并遵循編譯器采用的 Tile 級抽象,與現有框架提供的以算子為中心的原語形成顯著區別。如下圖 Figure 3 所示,TileLink 原語分為信號原語(Siganl Primitive)和數據原語(Data Primitive)兩大類,每類均包含 Device-side 原語和 Host-Side 原語兩個子類。
涉及的所有原語如下表 Table 3 所示:
4.2.3 信號原語
信號原語:旨在管理通信和計算之間的屏障,包括:
- producer(peer)_tile_notify:生產者或 Peer 通知
- consumer(peer)_tile_wait:消費者或 Peer 等待
- rank_notify(wait):Rank 通知和等待
在 Device-side:
- producer_tile_notify 和 consumer_tile_wait 適用于生產者-消費者關系,例如 AllGather 與 GEMM 運算中各 Tile 的交互;
- peer_tile_notify 和 peer_tile_wait 主要用于跨不同 Rank 的同一算子 Tile,使用戶能夠構建多樣化的 Tile 執行順序。
在 Host-side:
- rank_notify 和 rank_wait 用于管理 Copy Engine 和計算核心間的同步屏障。當通信任務映射至 Copy Engine 時,這些原語可有效協調通信與計算間的 Tile 執行順序。如上圖 Figure 3a 所示。
Notify 原語需通過 Mode Argument 或 Rank argument 明確待通知的遠端 Rank 范圍。TileLink 為 Mode Agrument 提供兩種選項:p2p 和 broadcast。
- p2p 僅通知單個目標 Rank,其數值由給定 Tile 標識(tile_id)在全局張量視圖中的偏移量計算得出;
- broadcast 則向所有 Rank 發送通知信號。
內存一致性:在并行執行過程中,不同進程/線程執行的內存操作可能以非一致順序對其他進程/線程可見。內存一致性模型通過設定約束條件,確保各進程/線程觀測到的操作順序不存在歧義。信號原語提供了嚴格的內存一致性語義:
- 通知類原語具有釋放語義(release semantics),保證所有在 producer(peer)_tile_notify 和 rank_notify 之前的內存訪問操作不得被重排到這些通知原語之后;
- 等待類原語則具有獲取語義(acquire semantics),確保所有在 consumer(peer)_tile_wait 和 rank_wait 之后的內存訪問操作不得被重排到這些等待原語之前。
這種嚴格的內存一致性約束在后端編譯階段同樣需要予以考慮。
4.2.4 數據原語
數據原語促進了數據傳輸過程,主要包括 tile_push(pull)_data 和 rank_copy_data 兩類原語。這些原語精確控制著傳輸數據的資源映射與 Tile 大小。
- Device-side 的 tile_push(pull)_data 原語將通信映射至處理核心。
- Host-side 的 rank_copy_data 原語則將通信映射至 Copy Engine。
數據傳輸存在拉取(pull)與推送(push)兩種模式,各自適配不同的同步機制:
- 在 pull 模式下,生產者從所有其他 Rank 讀取數據,并通過本地屏障通知其消費者;
- 與之相反,push 模式允許生產者將本地數據寫入所有其他 Rank,同時向遠端消費者發送數據到達通知。
如上圖 Figure 3b 清晰展示了兩種模式的差異。模式選擇可能影響性能表現,具體取決于數據形態、分塊策略及可用硬件資源等要素。值得注意的是,rank_copy_data 原語通過 P2P 復制技術支持雙模式運行,其數據傳輸方向由源指針與目標指針的排列順序顯式指定。
4.3 后端映射(Backend Mapping)
TileLink 后端負責將通信與計算組件共同編譯為底層設備代碼。為實現分布式系統的代碼生成,TileLink 采用了一種以計算單元為核心的映射技術,該技術能夠將通信模塊與計算模塊進行關聯整合。
TileLink 采用以 Tile 為中心的映射方法,將前端原語編譯為底層代碼。以 Tile 為中心的映射包含三個組成部分:
- 形狀映射(shape mapping):將每個 tile_id 與特定的 Tensor Shape Tile 相關聯。
- Rank 映射(rank mapping):將每個 tile_id 與 Device Rank 相關聯。
- 通道映射(channel mapping):為每個tile_id 分配通信屏障(communication barrier)。
作者分別用 fS、fR、fC 表示這三種映射。根據工作負載類型的不同,應采用不同的映射函數。作者將不同映射劃分為兩類:
靜態映射(Static Mapping):指可在編譯時確定的映射關系,通常用于數據分片策略固定的場景,例如 Tensor-Parallel MLP 和 Sequence-Parallel Self-Attention。作者采用仿射運算(Affine Operation)處理靜態映射(此時 fS、fR、fC 均為仿射函數)。以包含 R 個設備(每 Rank 對應 C 個通道/屏障)的系統上執行 AllGather(pull 模式)+ GEMM(問題規模 M×N×K)為例:生產者 AllGather 操作的 Tile 尺寸為 Tmp × Tnp,輸入 Tensor 沿 M 維分片。給定生產者 Tile 的 tile_idp,其形狀范圍、源 Rank 及通道可通過以下公式計算。類似地,可以計算出從消費者 tile_idc 到形狀范圍、 Rank 和通道的映射關系:
動態映射(Dynamic Mapping):是指在運行時計算的映射關系,這對于具有動態數據分片需求的工作負載至關重要。例如,在 MoE 數據分片策略中,動態路由決定了數據分布,每個 tile 可能需要來自其他任意 Rank 的 Token。在編譯時無法確定需要從哪些 Rank 收集數據或在哪個通道等待屏障同步。因此,必須在運行時計算這些映射關系。為支持動態映射,TileLink 將這些映射轉換為查找表,其值可在運行時填充,而對這些查找表的訪問操作則在編譯時確定。從形式化角度來看,動態映射如下所示(其中 fS_low,fS_high,fR 和 fC 是查找表,其值在運行中動態調整):
內存一致性編譯:在后端編譯過程中,前端具有內存一致性語義的原語被編譯為相應的設備指令(如 ld.global.acquire 和 red.release)。然而,直接翻譯這些原語并不足以確保內存一致性。對于大多數計算 Kernel,采用多級流水線技術來提升負載-計算平衡并優化整體性能。將原始程序編譯為多級流水線版本需要進行算子重排,在此過程中某些內存訪問操作可能會意外地被重排至 TileLink 原語之前或之后。為解決這一問題,TileLink 在其原語與后續 load/store 操作之間建立了嚴格的數據依賴關系,從而確保其原語能夠通過流水線處理階段被正確重排序和展開。
其他編譯優化:除上述技術外,TileLink 還采用單設備優化策略以實現高性能,該策略在已有研究中得到充分論證。優化主要體現在內存優化與流水線優化兩方面:
- 內存優化通過自動分配片上寄存器緩存和計算用共享存儲緩沖區,對全局緩沖區的數據訪問進行合并操作,并重構共享存儲器訪問模式以避免存儲體沖突;
- 流水線優化則通過重組數據 load/store 操作與計算任務,構建多級流水線架構。其中,本地數據拷貝被映射至專用異步引擎(如 GPU 的 TMA),而計算任務則被分配至高性能運算單元(如 GPU 的 Tensor Core)。
4.4 Kernel 設計
為展示 TileLink 的靈活性與普適性,作者闡釋了如何為 GEMM + Ring ReduceScatter、AllGather + MoE 以及 AllGather KV + Self Attention 機制設計 Overlap 計算 Kernel。這三個案例具有代表性:它們分別采用了不同的分片順序(Ring 和 All2All)、不同的映射策略(靜態與動態)以及不同的硬件資源(Device-side 和 Host-side)。
如下圖 Figure 4 展示了 GEMM + Ring ReduceScatter Kernel 的偽代碼實現,該案例采用靜態映射策略,演示了生產者-消費者和 P2P 雙向通信的編程范式。
- 其中計算與通信均采用 SM,分配了 20 個 SM 專用于通信(見第 1 行)。
- 生產者 GEMM 將部分計算結果存儲于本地 Tensor,并通過 producer_tile_notify 通知消費者(第 9 行)。
- 消費者 ReduceScatter 通過 consumer_tile_wait(第 16 行)等待生產者就緒。
- 一旦數據可用即執行 local reduce 操作(第 20 行),并將部分結果通過 tile_push_data 傳遞給前序節點(第 24 行)。
- 節點間的信號控制通過 peer_tile_wait(第 19 行)和 peer_tile_notify(第 26 行)原語實現。
如下圖 Figure 5 展示了 AllGather + MoE 的偽代碼實現。
- 同樣采用 20 個 SM 處理通信任務(第 1 行)。值得注意的是,MoE 需要基于動態路由(輸入中的 topk_ids)為每個 token 選擇專家,必須采用動態映射。因此使用 table 數據結構存儲形狀映射、Rank 映射及通道映射的查找表。所有相關原語均需以 table 為參數,以確保 TileLink 能基于動態映射生成正確代碼。
- 此外,load 原語需要借助 table 中的形狀映射來收集當前分片所需的正確 token(第 11 行)及其對應的 topk_ids(第 12 行)。
如下圖 Figure 6 展示了 AllGather KV + Self Attention(序列并行)的偽代碼。本案例中通信操作通過 Copy Engine 實現,采用 Host 原語來觸發 Copy Engine。通信與計算分別在兩個獨立的流上執行:
- 通信部分通過 rank_copy_data 原語完成,其分塊尺寸為 KV Cache 序列長度(S)除以總 Rank 數(WORLD_SIZE)。
- 計算部分則采用不同的分塊尺寸。通過基于分塊的 Kernel 映射機制,確保通信與計算環節間的屏障操作正確執行。
4.5 實現
TileLink 基于 Triton,使用 Python 語言實現。作者在 Python 層面實現了以計算塊為中心的原語操作,從而擴展了 Triton 的語言特性,而面向計算塊的映射機制則通過 Python 抽象語法樹(AST)轉換實現。其實現方案可輕松適配至 TVM、MLIR 等其他編譯器框架。
如下圖 Figure 7 所示,編譯器輸入為融合 TileLink 原語與 Triton 原生原語的純 Python 程序。通過特殊參數 BlockChannel 為計算和通信提供以計算塊為核心的映射上下文,BlockChannel 封裝了分布式映射元數據,包括當前進程 Rank、總 Rank 數、同步屏障配置及生產者/消費者計算塊關系等。
- Python 程序經解析生成 AST 后轉換為 Triton 中間表示(IR),在此過程中 BlockChannel 參數被分解,利用其內嵌元數據構建面向計算塊的映射關系,TileLink 原語則轉換為 Triton 的 ElementwiseInlineAsmOp 操作。
- 隨后 Triton IR 被進一步降級為 Triton GPU IR 和 TileLink 新增的 Distributed IR,后者用于將通過 ElementwiseInlineAsmOp 表達的特殊指令轉換為 LLVM IR,最終編譯為適用于 NVIDIA GPU 的 PTX 代碼。
- 通過將 LLVM IR 轉換為目標架構特定的底層匯編,可支持更多后端硬件。
- 運行時:
- 采用 NVSHMEM 初始化分布式執行環境并分配共享內存。
- 生成的代碼在所有進程上啟動以執行并發計算與通信。
- 運行結束后正確釋放共享內存空間。
五、評估
如下圖 Figure 8 所示,作者在 8xH00 集群上測試:
- 對 AG+GEMM 場景,Async-TP PyTorch 由于分解后的 GEMM 運算規模過小無法充分占用設備資源,未能實現加速效果。FLUX 憑借高度優化的實現取得了最高加速比(相較于 cuBLAS + NCCL 達1.34x)。TileLink 同樣實現了優于 cuBLAS + NCCL 的加速效果(1.27x),達到 FLUX 性能的 94.5%。
- 對于 GEMM + ReduceScatter 場景,TileLink 展現出最佳性能:較 cuBLAS + NCCL 提升 1.25x,較 Async-TP PyTorch 提升 2.22x,較 FLUX 提升 1.28x。
如下圖 Figure 9 所示,MoE 層相較于 MLP 層復雜度顯著提升,在編譯階段需進行動態映射。該層可分解為兩個核心部分:AG + Gather + Group GEMM 與 Group GEMM + Scatter + Topk Reduce + RS。這兩類算子可融合為 Group GEMM Kernel,vLLM 已實現此類融合運算。
- 在第一部分:TileLink 憑借通信-計算 Overlap 優化,在 vLLM 基礎上進一步實現 1.51x 平均加速。
- 在第二部分:TileLink 相較 vLLM 獲得 1.31x 平均加速,較 CUTLASS + NCCL 組合提升 10.56x。
- 需特別指出,FLUX、Async-TP PyTorch 等現有庫均不支持 MoE 層 Overlap 執行,而 TileLink 憑借靈活的原語體系與動態映射機制實現了該功能支持。
如下圖 Figure 10 所示,作者針對 16K 到 128K 序列長度的 Self Attention 機制進行了評估。實驗表明,在所有序列長度條件下,TileLink 方案相較 PyTorch 非 Overlap 實現(Torch)與RingAttention(RingAttn)均展現出穩定的加速優勢。經量化分析,TileLink 平均可獲得 5.04x 于Torch、1.97x 于 RingAttn 的性能提升。
作者將 TileLink 集成至 PyTorch 框架,并在 H800 集群上對 8 種不同 LLM 進行端到端性能評估。
- 首先在單節點(8×H800 GPU)環境下進行測試,結果如下圖 Figure 11 左半部分所示。前五種為 Dense 模型,后三種為 MoE。其中 Qwen1.5 采用 MoE 共享專家機制,通過將 MLP 層與 MoE 層合并來實現共享專家支持。實驗設置 Batch Size 為 4、序列長度 8192。結果表明, TileLink 相較 PyTorch 實現平均 1.32x 加速。Dense 模型平均加速比為 1.20x,與單層 MLP 加速效果一致——盡管 Self Attention 獲得顯著加速,但端到端性能仍由 MLP 層主導。MoE 模型平均加速比為 1.54x,低于單層 MoE 加速效果,因其 MLP 層與 MoE 層各占約 50% 執行時間,最終加速比介于二者之間。
- 在多節點部署評估中,鑒于節點間帶寬限制,采用節點內 TP 與節點間 DP 的混合策略。雙節點(各 8×H800 GPU)測試結果與單節點基本一致(Batch 規模倍增),整體加速比為 1.29x,因節點間通信開銷略有下降。
六、參考鏈接
- ??https://arxiv.org/abs/2503.20313??
- ??https://dl.acm.org/doi/10.1145/3620666.3651379??
- ??https://arxiv.org/abs/2301.03598??
- ??https://arxiv.org/abs/2406.06858??
- ??https://arxiv.org/abs/2105.05720??
- ??https://openreview.net/forum?id=MIJtDiMUX9??
本文轉載自???AI閑談???,作者:AI閑談
