Meta 新作:FlashAttention 的數值偏差有多大?
一、背景
最近 Meta 的研究員開發了一個新的框架來了解 LLM 訓練中數值偏差的影響,并基于該框架評估了 LLM 中廣泛采用的 FlashAttention 的數值偏差。
對應的論文為:[2405.02803] Is Flash Attention Stable?
PS:其實論文很簡單,結論也很簡單:使用 FlashAttention 相比 Baseline Attention 確實會帶來數值偏差。但帶來的數值偏差比從 FP32 到 FP16 的數值偏差小得多,甚至小于不同初始化方法帶來的偏差。吐槽一下,論文中的圖都比較模糊。
二、摘要
LLM 預訓練的代價很高,也更加的復雜。很多 LLM 在預訓練中都遇到了訓練過程不穩定的情況,通常表示為損失的毛刺(Spike)。數值偏差(Numeric Deviation)被認為是導致這種訓練不穩定的潛在原因,但由于訓練的成本很高,量化這一點非常有挑戰性。
本文中,作者開發了一種系統性的方法來理解數值偏差的影響,并使用廣泛采用的 FlashAttention 來驗證了該框架。作者發現,與 Baseline Attention 相比,在單個前向傳播中,BF16 下的 FlashAttention 會有超過一個數量級的數值偏差。然而,使用基于 Wasserstein 距離的數據驅動分析來提供數值偏差對訓練過程中模型權重影響的上限,發現 FlashAttention 中的數值偏差比低精度訓練的影響小 2-5 倍。
三、引言
3.1 數值精度
如下圖為常見的浮點數值精度,其中 sign 表示符號位,exponent 表示指數位,fraction 表示尾數位。相比 float32,float16 的指數位和尾數位都更小,而 bfloat16 的指數位和 float32 相同,只是尾數位更少。因此,通常 float32 轉 float16 時通常會帶來較大的精度損失,而 float32 轉 bfloat16 通常只需要做小數位的截斷,損失相對較小。現在的 LLM 預訓練中通常都會使用 bfloat16。
- Float32:指數位 8 位,尾數位 23 位,數據范圍為[1.18e-38, 3.40e+38]
- float16:指數位 5 位,尾數位 10 位,數據范圍為[6.10e-05, 6.55e+04]
- bfloat16:指數位 8 位,尾數位 7 位,數據范圍為[1.18e-38, 3.39e+38]
3.2 數值誤差
在浮點數的計算中會存在兩種常見的誤差:
- 溢出誤差(Overflow Error):浮點都有一個有限的表示范圍,當計算結果超出這個表示范圍時就會產生溢出錯誤,往往表現為無窮大。比如,令 float a = FLT_MAX * 2,此時 a 的值為正無窮大。
- 舍入誤差(Rounding Error):浮點數有固定的有效位數,當一個數值不能被精確表示時,就會被舍入到最接近的可表示的浮點數。這種輸入在數值計算中是不可避免的,因為大多數實數在計算機中無法被精確表示。比如在 C 中打印 0.1f,printf("a = %.20f\n", 0.1f),其輸出結果為 0.10000000149011611938,是一個近似值。
除此之外,有時也會提到下溢誤差(Underflow Error):當一個非常小的非零結果小于浮點數表示范圍下限時發生,通常導致結果被舍入為零。
由于 float16 和 bfloat16 的不同指數位和尾數位,也就導致它們出現誤差的場景不太一樣。
- float16:指數位較少,尾數位較多,表示范圍有限,但表示精度更高,因此更容易發生溢出誤差。
- bfloat16:指數位較多,尾數位較少,表示范圍更大,但表示精度有限,因此更容易發生舍入誤差。下溢誤差也更多一些。
3.3 訓練損失毛刺
在 Meta OPT、BigScience Bloom、Google PaLM、TII Falcon 以及智源 GLM 訓練中都出現了訓練損失出現毛刺的情況,也有一些有效的手段可以緩解,但依舊不知道其根因。比如 Google PaLM 中驗證了其并非是單個樣本導致的。
如下圖所示,是 [2211.05100] BLOOM: A 176B-Parameter Open-Access Multilingual Language Model 中遇到的毛刺現象:
3.4 評估指標
Wasserstein 距離,也稱為 Earth Mover’s Distance (EMD),是一種衡量兩個概率分布之間差異的方法。這種距離的直觀含義是,將一個概率分布轉變成另一個概率分布所需要的“工作量”或“成本”,其中“工作量”可以理解為將一堆形狀不同的沙子(一個概率分布)鏟動并重塑為另一堆沙子(另一個概率分布)所需要的努力。
Wasserstein 距離基于最優運輸理論。給定兩個概率分布 P 和 ??,以及一個成本函數 ??(??,??),Wasserstein 距離定義為將分布 P 轉變為 Q 所需的最小成本。數學上,它表示為:
這里的 π 是 P 和 ?? 之間的所有可能的聯合分布的集合,而 Π(P,Q) 表示所有這些聯合分布中,邊際分布分別是 P 和 Q 的集合。
相比其他距離度量(如歐氏距離或 KL 散度),Wasserstein 距離的一個主要優勢在于其能夠更加有效地處理概率分布之間的微小變化,特別是當這些分布不重疊或僅部分重疊時。這使得 Wasserstein 距離在數據稀疏或異構的情況下特別有用。
四、方法&實驗
4.1 方法
作者開發了一個 microbenchmark 來隔離和研究 FlashAttention 引起的數值偏差。其設計如下圖 Fig 2 所示,在原始的 FlashAttention 中只支持 FP16 和 BF16 格式,因此作者重新實現了 FlashAttention,以便分析不同的數值精度的影響。作者進一步修改模型,可以在每次調用 Attention 時計算 Baseline Attention 和 FlashAttention 的注意力矩陣輸出,從而可以使用最大差異(max difference)以及 Wasserstein 距離來度量差異。作者也進行了一系列訓練來度量整個訓練過程中模型權重的差異。
4.2 數據類型的影響
如下圖 Fig.3 所示,作者對比了不同數據類型下 Baseline Attention 和 FlashAttention 的數值偏差,可以看出,數值精度越高,偏差越小:
為了進一步分析這種數值偏差,作者探索了序列長度對數值偏差的影響,其中會保持 FlashAttention 的 tile 大小和 SRAM 大小相同。如下圖所示,隨著序列長度的增加,數值偏差也會適當增加。其中左圖(a)表示最大誤差,右圖(b)表示誤差的均值。由于序列變長,也就需要更多的 tile,相應也有更多的 resaling,這也就可能產生更多的誤差:
4.3 算法配置的影響
如下圖 Fig 6 所示,作者進一步探索了 FlashAttention 中不同配置的影響:
- (a)和(c)針對不同的 Block/tile Area 大小的影響,使用比較大的 Block 后 Baseline Attention 和 FlashAttention 的差異很小,主要是因為 rescaling 計算更少一些。
- (b)使用 Square Block 對 Baseline Attention 和 FlashAttention 的影響不大。?
4.4 模型權重的變化
作者進一步驗證了訓練中模型權重的變化(對比 Baseline Attention 和 FlashAttention),如下圖 Fig 7 所述,不管是最大誤差還是 Wasserstein 距離都會隨著訓練的迭代而逐漸變大,并且趨勢類似:
如下圖 Fig.8 所示,作者進一步驗證了整個訓練中其他變量帶來的模型權重的偏差。可以看出,雖然 Baseline Attention 和 FlashAttention 會導致權重產生誤差,但是其甚至比不同初始化方法帶來的誤差還小,更是遠小于 FP16 vs BF16 和 FP16 vs FP32 帶來的誤差:
五、參考鏈接
本文轉載自 ??AI閑談??,作者: AI閑談
