Sample Packing:長序列 LLM 訓練的 Attention 問題及優化
一、背景
之前看過部分 Megatron-LM 的源碼,也詳細分析過對應的 Dataset 和 DataLoader,想當然的認為在 LLM 預訓練時會使用 Document Level 的 Mask,也就是常說的 Sample Packing 技術。最近我們在做長序列訓練相關工作時發現并非如此,并且出現了一些很奇怪的性能問題,因此重新看了相關工作,并進行了部分實驗。
Sample Packing 中有很多可以討論的技術點,比如 Attention 的實現和優化,Sample 的組合及負載均衡問題(有點類似調度問題)以及不同方案對效果的影響等。我們這里只是先簡單介紹一下相關問題和實驗,后續會進一步探索更多工作,比如 Document Level 的 Mask 到底對預訓練效果影響有多大,對 Attention 進行優化還能帶來多少提升,如何設計一個比較好的 Packing 策略等?
相關工作可以參考我們之前的文章:
二、Dataset + Dataloader
之前的文章(???LLM 預訓練語料、預處理和數據集索引、加載總結??)中詳細介紹過 Megatron-LM(DeepSpeed-Megatron)中預訓練 Dataset 的存儲格式和 Dataloader 的加載、混合方式。簡單說來,預訓練通常包含很多不同的數據集,每個數據集又包含許多 Document。為了提升訓練效率,在實際訓練的時候一個 Sample(Sequence)里面可能會包含多個不同的 Document(Sample Packing)。比如 8K 的預訓練 Sequence Length,則一個 Sample 可以包含 8 個 1K 的 Document。
如下圖所示,簡單展示了 Megatron-LM 中如何 Packing 多個 Document,實際上就是一個多級的索引。需要說明的是,這里其實會引入很多隨機讀操作,會極大影響讀的性能。不過一般 LLM 計算代價都很高,這里也往往不會導致瓶頸。
三、Attention Mask
對于單個 Document 而言,Decoder Only 的 GPT 模型具有 Causal 特性,也就是每個 Token 不能看到之后的 Token,因此在實際訓練中需要添加 Attention Mask。如下圖所示,這種情況下 Attention Mask 是一個標準的下三角矩陣(Causal Mask),也就是綠色部分為 1,其他部分為 0:
如果一個 Sample 里包含多個樣本,則 Attention Mask 矩陣需要變成如下圖所示的塊對角矩陣形式(Block Diagonal Mask)。比如 Sequence Length 為 16,4 個 Document 的長度分別為 3,4,5,4,則對應 Attention Mask 矩陣如下圖所示,對角線上的 4 個矩陣(紅框)都是標準的下三角矩陣。按照這種方式可以保證和 4 個 Document 單獨作為 Sample 訓練是等價的:
四、Reset Attention Mask
4.1 是否需要
那么在實際使用中是否需要嚴格按照 Block Diagonal Mask 的方式使用呢?答案是否定的,比如 Megatron-LM 可以通過 reset_attention_mask 來控制是使用 Block Diagonal Mask 還是標準的 Causal Mask,默認值為 False。很多模型在預訓練時也會采用默認配置,即使用 Causal Mask。
在浪潮的 Yuan-1.0 報告(“源1.0”大模型技術白皮書)中有提到,為了避免不同 Document 之間的相互干擾而將 reset_attention_mask 設置為 True,也就是 Block Diagonal Mask:
在 Meta 的 LLaMA 3.1 技術報告([2407.21783] The Llama 3 Herd of Models)中也提到,在 LLaMA 3.1 模型的預訓練中會打開這個配置。不過作者也做了說明,對于 8K Sequence Length 的預訓練而言,對模型最終的效果影響不大,對長序列的 Continuous PreTraining 影響比較大:
在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者提出了“世界模型”,為了提升超長序列的訓練效率,作者采用了 Sample Packing 的策略,并且做了相關消融實驗。如下圖 Table 10 所示,采用 Naive Packing(不對 Attention Mask 特殊處理)相比使用了 Block Diagonal Mask 的 LWM 的性能會差很多:
PS:當然,目前還沒有更多有關預訓練中是否 reset_attention_mask 的消融實驗,我們后續會進行相關測試。此外,如果采用絕對位置編碼,Position-id 也需要相應的調整,在 Megatron-LM 中對應 reset_position_id 選項。
4.2 性能問題
如下圖為 Megatron-LM/megatron/core/datasets/gpt_dataset.py 中 reset_attention_mask 的實現方式,首先會將 attention_mask 初始化為標準的 Causal Mask 形式,然后從第二個 Document 開始,將之前的 mask 置為 0:
具體來說如下圖所示,初始是一個標準的 Causal Mask 矩陣,然后會將 4x3、5x(3+4) 和 4x(3+4+5) 的區域依次置為 0,之后會變成 Block Diagonal Mask:
實際上我們已經知道這里是標準的 Block Diagonal Mask,可以使用 torch.block_diag() 快速創建。實測當序列比較長時(比如 32K),兩種方式速度可能會差幾十倍,導致 reset_attention_mask 可能成為訓練瓶頸:
除此之外,當序列非常長時,Attention Mask 也會占據很大的存儲空間,為了計算效率,往往會使用整型而不是 Bool 類型。假設以 int8 存儲,32K 序列長度對應的 Mask 大小為 32K * 32K = 1GB,128K 時更是高達 16GB。為了避免顯存浪費,其實不必將其拼成大的 Block Diagonal Mask,而保留幾個小的 Causal Mask 即可。
五、Attention 優化
5.1 FlashAttention
當前 LLM 預訓練基本都會使用 FlashAttention,其對 Casual Mask 的方式進行了優化,如下圖所示,假設 16x16 的 Attention Mask,在計算時按照 4x4 分塊,則可以將其分為 3 種情況:
- 有些塊對應的 Mask 都是 0(紅框右上部分,比如藍框),無需再計算。
- 有些塊中部分 Mask 為 0,部分 Mask 為 1(紅框),需要相應特殊處理。
- 有些塊對應的 Mask 都是 1(紅框左下部分,比如黃框),全部計算即可。?
對于上述 Block Diagonal Mask,依然可以使用 Causal Mask 的方式計算,不過會導致大量的無效計算。幸運的是,FlashAttention V2 支持可變序列長度(Varlen)的 Batching Attention 計算,可以避免 Padding 導致的無效計算。因此也就可以借用這種機制來對 Block Diagonal Mask 進行解構,重新分解為多個 Causal Mask 分別計算,可以避免很多無效計算。如下圖所示,可以將其看成 4 個獨立的 Attention 計算,具體可以參考 FlashAttention Github 上的相關討論:How to implement example packing with flash_attn v2? · Issue #654 · Dao-AILab/flash-attention · GitHub 和 Will attention_mask be extended to 3D? (concatenate short samples for efficient training) · Issue #432 · Dao-AILab/flash-attention · GitHub。
在 GLM-4([2406.12793] ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools)中也應用了 Sample Packing 方案,并且同樣使用了 Block Diagonal Mask 機制來區分不同的 Document。并且作者也是基于 FlashAttention 的 Varlen 功能來實現。
5.2 Pytorch FlashAttention
Pytorch 的 scaled_dot_product_attention 提供了高效的 Attention 實現,也集成了 FlashAttention2 的實現,然而其不支持上述的可變序列長度的功能,導致針對 Block Diagonal Mask 場景時會存在大量的重復計算。
此外,我們在之前的文章中也多次提到,當序列比較短時,Attention 部分計算的占比并不是特別大,因此其中的冗余計算可能對整體訓練速度影響不大;但當序列比較長時,Attention 部分計算的占比會越來越大,冗余計算可能會對訓練速度有比較大的影響,也就需要對其進行優化。
5.3 FlexAttention
Pytorch 在 2.5.0 版本引入了 FlexAttention(FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention),可以很容易支持各種 Attention Mask 變種,比如標準 Causal Mask、Sliding Window + Causal、Prefix Mask 以及 Document Mask(Block Diagonal Mask)等,相比 FlashAttention 也更加的靈活。
我們基于 FlexAttention 進行了相關測試,以驗證使用 Block Diagonal Mask 的性能優勢。首先以兩個 16K Document 拼接為一個 32K Sample 為例,Attention Mask 大概是如下圖所示方式,對應的稀疏度為 74.80%(整個 Mask 中 0 的占比):
如下圖所示我們在 H100 GPU 上進行的 Attention 相關性能測試。可以看出, Pytorch 的 Causal + FlashAttention2 方式確實可以達到非常高的 TFLOPS,明顯高于 FlexAttention。然而,因為 FlexAttention 中避免了很多無效計算,實際的 Forward 和 Backward 時間反而更短:
當然,也并不意外著 FlexAttention 總是更優的,還和 Sample 中 Document 長度有關。如下圖所示為相應測試結果,32K 表示 Sample 中只有一個 Document,2K + 30K 表示 Sample 中有 2 個 Document,一個長度 2K,一個長度 30K。從下圖基本上可以得出這樣一個結論:當 Sample 中最長的 Document 的長度 <= Sequence Length/2 時,使用 FlexAttention 可能會帶來更大的收益:
那么為什么“最長的 Document 的長度 <= Sequence Length/2”時會有收益呢?其實可以簡單從稀疏度的角度考慮:假設 a1 + a2 + a3 + ... + an = S,并且 0 < a1 <= a2 <= a3 <= ... <= an <= S/2,那么可以用數學歸納法得出 (a1)^2 + (a2)^2 + (a3)^2 + ... + (an)^2 <= S^2/2。也就是說,最長的 Document 的長度 <= Sequence Length/2 時,稀疏度會 >= 75%(還要考慮 Causal 特性),相應的 FlashAttention 中至少有一半的冗余計算。
因此,我們也需要充分考慮在長文本訓練過程中短文本的占比,極端情況下訓練數據全部是超長文本,每個 Sample 中都只有一個 Document,Block Diagonal Mask 會退化為 Causal Mask。不過有些時候為了避免模型出現災難性遺忘,也會混合一些短文本數據,或者高質量的預訓練數據,不可避免的會出現冗余計算的問題。
5.4 Sequence Parallel
我們在之前的序列并行文章(???大規模分布式 AI 模型訓練系列——序列并行??)中也提到過,針對長序列場景通常會采用 RingAttention 和 USP 等,然而不管是 RingAttention 還是其 LoadBalance 版本(如下圖 Figure 3 所示)等都沒有太多討論 Sample Packing 的情況。對于 Block Diagonal Mask 場景,其相應的優化,LoadBalance 策略也可能需要對應調整:
在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者(也是 RingAttention 的作者)聲稱針對 Block Diagonal Mask 場景對 RingAttention 進行相關優化,但并沒有對比優化前后訓練速度的提升。
PS:整體來說,在各種序列并行技術中更好的兼容 Block Diagonal Mask 場景又會有更多的挑戰,我們留作后續介紹。
六、參考鏈接
- ??https://www.inspur.com/lcjtww/resource/cms/article/2526910/2726086/2022082918565451491.pdf??
- ??https://arxiv.org/abs/2407.21783??
- ??https://arxiv.org/abs/2402.08268??
- ??https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py??
- ??https://github.com/Dao-AILab/flash-attention/issues/654??
- ??https://github.com/Dao-AILab/flash-attention/issues/432??
- ??https://arxiv.org/abs/2406.12793??
- ??https://pytorch.org/blog/flexattention/??
本文轉載自 ??AI閑談??,作者: AI閑談
