Intel Smooth-SwiGLU:FP8 LLM 訓練,34% 加速
一、背景
本文中我們繼續介紹一個 Intel 最新的關于 FP8 訓練相關的工作,其在一定程度上分析并解決了 FP8 訓練中的不收斂問題,進一步推進了 FP8 訓練落地(尤其是在 H100/H800 GPU 上)的可行性。
對應的論文:[2409.12517] Scaling FP8 training to trillion-token LLMs [1]
二、摘要
本文中,作者首次在 2T Token 的數據集上使用 FP8 精度訓練了 LLM,比以前的限制增加了 20 倍。通過這些擴展訓練實驗,作者發現了 FP8 訓練中的關鍵不確定性,這些不確定性在早期持續時間較短的訓練中是無法觀察到的。作者進一步追溯到 SwiGLU 激活函數的異常值放大問題。有趣的是,從分析和經驗上表明,這些放大只發生在較長的訓練中,并將其與 SwiGLU 的權重對齊過程聯系起來。
為了解決這個新發現的問題,作者引入了 Smooth-SwiGLU,這是一種新穎的修改,可確保穩定的 FP8 訓練而不改變函數的行為。作者還首次演示了兩個 Adam 優化器參數(一階矩和二階矩)的 FP8 量化。
結合這些創新,作者在 256 個 Intel Gaudi2 加速器上使用 FP8 精度成功訓練了一個 7B 參數量的模型,實現了與 BF16 基線相當的結果,同時提供了高達 34% 的吞吐提升。
三、引言
3.1 浮點數值表示
我們之前的文章中提到過,雖然都是 FP8 精度,但是不同硬件上存在不同的表示方式,主要包括 E5M2 和 E4M3,其中 E 表示指數位(決定了動態范圍),M 表示尾數位(決定了表示精度)。此外,雖然都是 E5M2 或者 E4M3,不同的硬件可能采用不同的格式。比如 NVIDIA GPU 上的 E5M2 符合 IEEE 754 Style,而 E4M3 卻不符合 IEEE 754 Style,可以稱為 ARM-Intel-Nvidia Style。此外,AMD-Graphcore-Qualcomm 的表示也有所不同。如下圖所示,IEEE 754 Style 的 E4M3 范圍為 [-240, 240],而 ARM-Intel-Nvidia Style 的 E4M3 范圍是 [-448, 448]:
PS:由于 Intel 和 NVIDIA 采用一致的表示方式,這也就意味著 Intel 的結論能夠很容易擴展到 NVIDIA 的 GPU。
3.2 SwiGLU
在 [2002.05202] GLU Variants Improve Transformer [2] 中作者提出了在 Transformer 模型中使用各種 GLU 變體激活,也提到其在許多下游理解任務上獲得了更好的結果。然而,作者也提到并沒有解釋為什么這些修改會有效,將其歸功于上帝的恩賜。在后續 Google 的 PaLM,Meta 的 LLaMA 系列等模型中廣泛采用了 SwiGLU 激活。
如下圖所示為 FFN 中應用 SwiGLU 的公式:
其中 Swish 為激活函數,可以表示為:
其 Pytorch 的實現也很簡單,如下所示,這里用 silu 代替了 SwiGLU,對應 β 為 1:
因為這里有三個參數:w1,w2,w3,為了保證總參數量和正常 FFN 一致,所以 LLaMA 中這里的 Hidden Dim 不是 4d,而是 4d*2/3=8d/3。
3.3 GLU 類激活的離群點
如下圖 Figure 1 所示,在 [2405.14428] Mitigating Quantization Errors Due to Activation Spikes in GLU-Based LLMs [3] 中作者發現,各種 GLU 變體的激活函數容易在特定層(比如基于 SwiGLU 激活的 FFN 的最后一個 Liner 層的輸入)出現激活的 Spike。此外,作者發現這些激活的 Spike 與中間層隱藏狀態(Hidden Stage,每個 Transfomer Block 的輸出)之間也存在高度相關性。并且 FFN 可能會通過殘差連接中的加法運算放大 Hidden Stage。一旦 Hidden Stage 被放大,它就會在各層中持續存在,直到之后的層中再次遇到激活 Spike。
3.4 延遲縮放
在 FP8 的 Per Tensor Scaling 技術中,有兩種常見的方式:Just-in-time Scaling 和 Delayed Scaling(可以參考 NVIDIA Transformer Engine 中的實現 Using FP8 with Transformer Engine [4])。
- Just-in-time Scaling(實時縮放):直接計算 Tensor 絕對值的最大值(amax),然后得到 Scaling 值,再對 Tensor 進行 Scaling。此種方法更加精確,但是,額外引入的開銷會大幅降低 FP8 帶來的收益。
- Delayed Scaling(延遲縮放):核心思路是使用額外的 Tensor 來存儲之前的 amax 歷史,然后根據歷史最大值估計當前的最大值。
如下圖為 NVIDIA Transformer Engine 中的 Delayed Scaling 實現方案,amax history 最多可以存儲 1024 個 history。在進行當前 Tensor 的 Scaling 操作時,使用當前 Tensor 之前的 amax history 來預測當前的 amax,然后再進行 Scaling 操作;Scaling 操作的同時會計算當前的 amax,并更新 amax history。
四、方案
4.1 洞察
上面提到的 GLU 類激活引出的離群點(Outlier)問題會為 FP8 訓練帶來很多挑戰。本文中,作者揭示,在大規模數據集上訓練 LLM 的后期階段,這些異常值變得尤為顯著。
如下圖 Figure 1 所示,(a)為訓練的起始階段,沒有 Outlier;(b)為訓練 200B Token 之后,出現偶發性的 Outlier。這些 Outlier 僅在訓練中處理了很長一段時間才出現,此現象對維持訓練中的數值穩定性帶來極大挑戰,進一步增加了 FP8 訓練穩定性的難度,尤其是像 Megatron-LM 中廣泛采用的延遲縮放方案中(如上述的介紹,這些方案假設了迭代期間的一致性)。
如前所述,SwiGLU 激活函數可能導致 FFN 組件的最后一個 Linear 層的輸入出現 Outlier。在使用 FP8 訓練,并采用延遲縮放技術時,SwiGLU 引發的 Outlier 會打破延遲縮放的統計一致性假設,導致訓練過程的不穩定性。如下圖 Figure 3 所示,作者展示了在 FFN 的最后一個 Linear (也就是 SwiGLU 的輸出)禁用量化后的訓練收斂性,LLaMA2 FP8 的訓練能夠成功地在大規模數據集上收斂,從而解決先前觀察到的發散問題,這也驗證了 SwiGLU 對 FP8 訓練穩定性的影響。
4.2 SwiGLU 相關問題證明
如下圖所示為 SwiGLU 的定義,其中,SwiGLU 由輸入 x 與權重 w1、w2 進行乘積,分別得到 xTw1 和 xTw2。然后使用 Swish 激活函數對 xTw2 進行變換,將其與 xTw1 相乘。
其他標準激活函數(如 ReLU、GeLU 和 Swish)在輸入幅度較大時最多是線性的。這意味著當輸入 u 逐漸趨于正無窮或負無窮時,這些激活函數的比值(即 ∣f(u)/u∣)會趨于小于等于 1 的某個值。然而,SwiGLU 是一個二次函數,可以達到更大的值,尤其在權重 w1、w2 相互“對齊”時(例如 w1=w2,并且 ||w1||=1)。
當 w1=w2 時,上式可以表示為:
假設 xTw=c,則當 c 較大時,σ(c) 趨近于 1,此時下述結果約等于 c2,也就是上述所說的二次放大特性。
作者也進一步通過理論分析證明了上述 w1、w2 相互“對齊”現象,這里不再贅述,具體可以查看論文。當然,作者也進一步通過實驗驗證了相關問題。如下圖 Figure 2 所示:
- (a):LLaMA2-7B 模型在 BF16 和 FP8 精度下的訓練損失,其中 FP8 的訓練在達到 200B Token 后開始出現明顯的發散現象。
- (b):展示了某一個特定 Channel 在訓練過程中 w1 和 w2 范數的動態變化及其相關性。
- (c):某一個 Outlier 通道在訓練初期(8B Token)與后期(330B Token)w1 和 w2 元素散點圖。可以看出,后期相關性顯著提高,趨近于 w1=w2。
- 某一個 Outlier 通道在訓練初期(8B Token)與后期(330B Token)w1 的直方圖分布。?
除了上圖 (c) 那樣的正相關性外,作者也觀察到了明顯的負相關性,也就是趨近于 w1=-w2,如下圖所示:
4.3 Smooth-SwiGLU
為了在解決 Outlier 問題的同時保持完整的 FP8 加速,作者提出了 Smooth-SwiGLU 方法。如下圖 Figure 4 所示,其展示了 Smooth-SwiGLU 的核心理念:對 SwiGLU 函數的線性分支施加一個縮放因子,并在最后一個 Linear 后將其重新縮放。此方法防止了輸入到最后一個 Linear 的量化過程中出現 Outlier,同時保留了 SwiGLU 激活函數的整體功能,使能夠在整個網絡中充分利用 FP8 精度。
為降低計算開銷,作者采用了一種高效的并行方法來計算縮放因子 si,也就是 Per Channel 量化:
1. 將 Tensor 分割成若干塊,每塊對應一個通道。
2. 對每個塊(通道),并行計算其最大值。
3. 利用這些每個通道的最大值,確定各通道的獨立縮放因子 si。
此方法實現了高效的逐通道縮放,因為每個通道的縮放因子是獨立并行計算的。相較于 Linear 層中的矩陣乘法,這種方法的計算成本適中,尤其是在并行化處理下,即便在非優化實現中也是如此。在推理階段,這些縮放因子可以合并到包含 SwiGLU 層及其后 Linear 的 FFN 的第一和第三Linear 的權重中。
五、實驗&結果
5.1 FP8 優化器
Adam 優化器及其變體在深度學習中得到廣泛應用,Adam 優化器的一個關鍵特征是其存儲兩個矩,傳統上采用高精度(FP32),這顯著增加了內存開銷,尤其對于大規模模型而言。盡管先前研究中已證明將其一階矩降至 FP8 精度的可行性,但仍保留二階矩為 FP16。本文中作者則更進一步,成功將二階矩也量化至 FP8,顯著提升了大型語言模型優化器的效率。
如下圖 Figure 5 所示,作者探索發現一階矩 Mom1 采用 E4M3,二階矩 Mom2 采用 E5M2 可以很好的維持訓練精度:
5.2 訓練效果實驗
如下圖 Figure 6 所示,作者在 256 個 Gaudi2 上訓練 LLaMA2 7B 模型,共訓練了 330B 左右 Token。本文的 Smooth-SwiGLU + FP8 優化器可以和 BF16 訓練維持相當的 Loss,而傳統的 FP8 訓練在 200B Token 時開始發散。
如下圖 Table 2 所示,作者進一步對比了下游任務的 Zero Shot 精度以及困惑度,可以看出,本文 FP8 訓練出的模型可以很好的保持精度:
5.3 訓練速度
如下圖 Table 3 所示,可以看出,本文方案在保持收斂性的同時獲得了更高的加速比,相比 BF16 訓練可以加速 33.52%(不及 FP8 主要是因為引入了一些額外的開銷)。
雖然額外引入了一些計算開銷,但顯存并沒有明顯增加:
六、參考鏈接
