Sample Packing 綜述:LLM 效果與效率的 Tradeoff
本文中我們通過幾篇論文來具體介紹 Sample Packing 相關(guān)的方案和對應(yīng)的各種問題,比如 GraphCore 的PackedBert、Meta 的 In-Context-Pretraining、智譜 AI 的 LongAlign、Amazon 的 Fewer Truncations 以及 IBM 的 Packing with FlashAttention。
一、背景
上一篇文章(???Sample Packing:長序列 LLM 訓(xùn)練的 Attention 問題及優(yōu)化??)中我們簡單介紹了 Sample Packing 相關(guān)的問題和部分簡單實驗。本文中我們通過幾篇論文來具體介紹 Sample Packing 相關(guān)的方案和對應(yīng)的各種問題,比如 GraphCore 的PackedBert、Meta 的 In-Context-Pretraining、智譜 AI 的 LongAlign、Amazon 的 Fewer Truncations 以及 IBM 的 Packing with FlashAttention。
二、方法
Sample Packing 可以看成是一個經(jīng)典的 Bin Packing 組合優(yōu)化(Combinatorial Optimization)問題,核心思想是將一組不同體積的物品放入容量固定的箱子中,目標(biāo)是最小化所需箱子數(shù)量。
Bin Packing 是一個典型的 NP-Hard 問題,因此往往需要關(guān)注其效率。對于 N 個 Sample 的數(shù)據(jù)集,通常需要先排序,對應(yīng) O(N*logN) 的時間復(fù)雜度;然后 Packing,需要 O(N*logN) 的時間復(fù)雜度。不過由于 Sample 的長度通常比較有限,可以采用計數(shù)排序的方式進(jìn)行優(yōu)化;此外,整個訓(xùn)練集通常有多個數(shù)據(jù)子集組成,為了保證每個數(shù)據(jù)子集具有不同的采樣權(quán)重,往往一個 Sequence 中的 Sample 來自同一個數(shù)據(jù)子集,一個大的 Bin Packing 問題也變成了多個小的 Bin Packing 問題,可以進(jìn)一步降低復(fù)雜度。本文中我們就不再具體介紹 Bin Packing 的優(yōu)化問題。
如果從物品順序角度考慮,Bin Packing 可以分為兩類:
- Online Bin Packing:物品按順序到達(dá),必須立即決定放入哪個箱子,無法預(yù)知后續(xù)物品的大小。對于 LLM 訓(xùn)練而言,Online 方式依然可以實現(xiàn) Batch 內(nèi)(窗口內(nèi))的亂序,也可以通過梯度累加增加 Batch 的大小。
- Offline Bin Packing:預(yù)先知道所有物品大小,可以全局排序。對于 LLM 訓(xùn)練而言,相當(dāng)于訓(xùn)練前就預(yù)先知道所有序列的長度,并對所有 Sample 打包。
當(dāng)然,Bin Packing 方案也可能有不同的約束條件:
- 一次性 Packing:每個物品只能放進(jìn)一個箱子,不允許拆分物品。對于 LLM 而言,可以理解為 Sample 不能截斷。
- 多維 Packing:物品和箱子不僅有體積,還有其他屬性,比如重量、形狀,需要同時滿足多維約束。對于 LLM,可以約束同一個 Batch 內(nèi)不同 Sequence 的計算量盡量類似,以實現(xiàn)更好的負(fù)載均衡。
整體來說,對于 LLM 訓(xùn)練可以從數(shù)據(jù)分布和計算效率的角度考慮相關(guān)的方案,數(shù)據(jù)分布可能影響模型訓(xùn)練的效果,計算效率會影響訓(xùn)練的速度,往往需要綜合考量。
影響數(shù)據(jù)分布的幾個因素:
- 是否隨機采樣,比如排序機制可能引入分布的不一致。
- 是否交叉污染,比如 Packing 是否采用 Document Level 的 Mask(Block Diagonal Mask,也可以通過 Position ID 區(qū)分)。
- 是否有 Sample 的截斷。
影響計算效率的幾個因素:
- 是否 Padding,Padding 是否參與計算。
- Packing 是否采用 Document Level Mask。
- 是否存在負(fù)載不均衡問題(主要是指稀疏度不同,計算量不同)。
三、GraphCore PackedBert
3.1 概述
GraphCore 在 2021 年 07 月的 [2107.02027] Efficient Sequence Packing without Cross-contamination: Accelerating Large Language Models without Impacting Performance 中已經(jīng)討論了 Sample Packing 導(dǎo)致的交叉污染(Cross Contamination)問題。作者研究了新的 Packing 算法,并通過修改 Attention Mask 和 Position ID 來避免交叉污染,提升 Bert 模型的訓(xùn)練效率。
如下圖 Table 1 所示,不同的 Packing 方式有不同的 EFF(有效率),其中 Baseline 的 None 表示有 50% 左右的 Padding Token,而 SPFHP 和 NNLSHP(本文提出) 能獲得相對比較高的 EFF。
3.2 結(jié)果
如下圖 Figure 4 所示可以看出,需要相應(yīng)修改 Attention Mask 才能保證精度:
如下圖 Figure 5 所示,隨著 Accelerator 的增加,本文的方案能獲得比較穩(wěn)定的加速,非常接近理論速度(最上的藍(lán)線);而非 Padding 的方案可能隨著 Accelerator 的增加出現(xiàn)明顯的降速。
四、Meta In-Context-Pretraining
4.1 概述
Meta 作者在 [2310.10638] In-context Pretraining: Language Modeling Beyond Document Boundaries 也關(guān)注了 Sample Packing 的問題。作者指出,之前的預(yù)訓(xùn)練流程在訓(xùn)練時將隨機的短文檔拼接起來形成輸入上下文,這些文檔之間沒有提供預(yù)測下一個文檔的信號,導(dǎo)致計算效率不高。因此作者提出了新的方案:In-context Pretraining,通過改變文檔的順序,使得每個上下文包含相關(guān)的文檔,從而明確鼓勵模型跨文檔邊界進(jìn)行閱讀和推理。
4.2 方案
如下圖所示,在預(yù)訓(xùn)練前會計算文檔的相似性,在 Packing 時利用上這種相似性,保證 Sequence 中文檔盡可能相關(guān)。
4.3 結(jié)果
作者使用 CommonCrawl 數(shù)據(jù)集,預(yù)訓(xùn)練了 0.3B 到 7B 參數(shù)量的多個模型,并在多種任務(wù)上評估,包括上下文學(xué)習(xí)、閱讀理解、對先前上下文的忠實度、長上下文推理和檢索增強等。與使用標(biāo)準(zhǔn)方法預(yù)訓(xùn)練的模型相比,In-context Pretraining 方法訓(xùn)練出的模型(ICLM)顯示出顯著的性能提升。
如下圖 Figure 3 所示,本文 ICLM 訓(xùn)練的模型獲得了更低的困惑度:
如下圖 Table 1 所示,基于 ICLM 訓(xùn)練的模型在下游 In-context Learning 任務(wù)上也獲得了更好的效果:
五、智譜 LongAlign
5.1 概述
智譜 AI 在 [2401.18058] LongAlign: A Recipe for Long Context Alignment of Large Language Models 中討論了部分 Sample Packing 相關(guān)問題。如下圖 Figure 3 左圖所示,Sequence 的長度各不相同,從 0 - 60K,如果采用 Naive Batching 方式,會導(dǎo)致明顯的 Bubble 問題(雖然 NoPadding 技術(shù)可以避免重復(fù)計算,但是如果采用 Data Parallelism 方式,比較快的設(shè)備需要等待比較慢的設(shè)備計算完成)。為了解決效率和效果問題,作者提出了 3 種解決方案:Packing、Loss Weighting 和 Sorted Batching。
5.2 Packing
如 Figure 3 右上圖所示,就是我們之前介紹的 Sample Packing:將不同的 Sample 拼接在一個 Sequence 里,并且保證盡可能接近 Max Sequence Length,末尾的部分 Token 進(jìn)行 Padding。然后通過 Block Diagonal Attention Mask 來區(qū)別不同的 Sample,以避免 Sample 之間的交叉污染,也就是 Document Level Attention。
PS:作者介紹,這里同樣是使用了 FlashAttention2 的 Varlen 特性。
5.3 Loss Weighting
假設(shè)訓(xùn)練時的 Batch Size 為 K,總共包含 M 個 Sample,第 i 個 Sample 的 Token 數(shù)為 Ni,則對應(yīng)的 Loss 如下圖所示:
然而,增加 Sample Packing 之后會引入一個問題,如下圖所示,一個 Sequence 中的不同 Sample 會被看成一個 Sample 來計算損失。當(dāng)有些 Sample 比較長,其對應(yīng)的 Token 很多,那么這個 Sample 對 Loss 的貢獻(xiàn)就更大,模型可能會在訓(xùn)練時更傾向于優(yōu)化長 Sample 的表現(xiàn),進(jìn)而可能會導(dǎo)致對短 Sample 的學(xué)習(xí)有所欠缺。
為了解決這個問題,作者提出 Loss Weighting,也就是對不同 Sample 的 Loss 加權(quán)。如下圖所示,保證其和上述公式(2)等價。作者聲稱可以在下游任務(wù)上帶來 10% 左右的效果提升。
PS:不過這里還會引入另外一個問題,因為采用了 Sample Packing,那么實際上不同 Step 中 Sample 的個數(shù)在不斷變化。比如一個 Batch 里都是短 Sample,那么對應(yīng)的 M 會比較大;如果一個 Batch 里都是長 Sample,相應(yīng)的 M 會比較小。這樣可能引入兩個問題,1. 相當(dāng)于一個 Batch Size 在不斷變化;2. 同樣一個 Sample 可能會因為順序的原因被賦予不同的權(quán)重。因此需要盡量保證Batch 中 Sample 的平均個數(shù)比較穩(wěn)定。
5.4 Sorted Batching
如上圖 Figure 3 右下圖所示,可以將所有 Sample 進(jìn)行排序,在組 Batch 時盡量保證一個 Batch 中的 Sample 長度相同(沒有 Packing)。這樣可以保證不同設(shè)備的計算盡可能的均衡。然而,這種方式也不可避免地引入不同 Batch 的數(shù)據(jù)分布的偏差,有些 Batch 都是長序列,有些 Batch 都是短序列,對于 SGD 優(yōu)化來說可能并不友好。不過作者發(fā)現(xiàn)這種方式可以顯著加快訓(xùn)練速度,而不會對效果產(chǎn)生明顯的負(fù)面影響。這可能是因為使用了大的梯度累加(Micro Batch 中長度類似,但整個 Batch 中包含各種長度的 Sample)。
PS:這種方式也就對應(yīng) Transformer 等工作中的 LengthGroupedSampler,如下圖所示,排序后可以有效降低無效計算(圖片參考 數(shù)據(jù)分組— XTuner 0.1.23 文檔)。
5.5 結(jié)果
如下圖 Figure 5 所示,作者在 8xA800 GPU 上進(jìn)行速度對比,可以看成,Packing 和 Sorted Batching 相比 Naive Batching 都有 2x-3x 的加速:
如下圖 Table 3 所示,作者基于 ChatGLM3-6B-64K 和 LLaMA-2-7B-64K 進(jìn)行了相關(guān)效果驗證。可以看出,Loss Weighting 在 LongBench-Chat 上能帶來 5%-10% 的提升,但是在其他任務(wù)上并不明顯,并且這些方法看著都不是特別魯棒。
5.6 Packing 負(fù)載均衡
智譜 AI 在訓(xùn)練 GLM4(GLM Long: Scaling Pre-trained Model Contexts to Millions | by ChatGLM | Medium) 模型時進(jìn)一步解決了 Packing 帶來的負(fù)載均衡問題。如下圖所示,雖然都 Packing 到了相同的長度,但是由于其中的 Sample 個數(shù)、長度不同,導(dǎo)致其稀疏度差距很大,計算量也相應(yīng)差距很大。如果它們在不同的設(shè)備上執(zhí)行,同樣會存在計算不均導(dǎo)致的 Bubble 問題。
如下表所示,我們在上一篇文章中的實驗也能說明這個問題,隨著 Sequence 中 Sample 分布的不同,計算的耗時甚至可能差 10x:
如下圖所示,作者發(fā)現(xiàn)訓(xùn)練中每個 Step 的時間存在較大幅度的波動,這種現(xiàn)象在短文本的 Packing SFT 中并不明顯。這是因為短文本時 Attention 的計算占比并不高,而超長文本訓(xùn)練中會尤其明顯。
為了解決上述問題,作者進(jìn)一步提出了 Sorted Packing。具體來說,作者在構(gòu)建 Batch 數(shù)據(jù)時考慮了計算復(fù)雜度,以確保每個 Batch 中計算復(fù)雜度相似,從而減少 Bubble 時間。(PS:這里需要注意,計算復(fù)雜度不等于執(zhí)行速度,如果能針對預(yù)估計算速度來打包也許能獲得更優(yōu)的效果)
作者指出也使用了 layer accumulation 技術(shù)(PS:上述介紹的梯度累加?)來避免排序?qū)е碌钠脝栴}。
六、Amazon Fewer Truncations
6.1 概述
在 [2404.10830] Fewer Truncations Improve Language Modeling 中,作者探討了數(shù)據(jù)截斷問題對模型效果的影響。作者指出,截斷會損害數(shù)據(jù)完整性,從而阻礙模型學(xué)習(xí)基于完整上下文撰寫邏輯連貫且事實一致的內(nèi)容的能力。
為了解決這個問題,作者提出了 Best-fit Packing,通過長度感知組合優(yōu)化來 Packing,可以完全消除不必要的截斷,同時基本不影響訓(xùn)練效率。通過文本和代碼預(yù)訓(xùn)練的實驗結(jié)果表明,提出的方法可以在閱讀理解上相對提升 4.7%,上下文跟隨提升 16.8%,程序生成提升 9.2%。此外,也可以有效減少 58.3% 的封閉域幻覺問題。
PS:論文中對比實驗時 Baseline 的 Concatenation 方案中有截斷,并且沒有使用 Block Diagonal Mask;而 Best-Fit Packing 使用了 Block Diagonal Mask,且沒有截斷。
6.2 方案
如下圖 Figure 1 所示為本文 Best-Fit Packing 與傳統(tǒng)方案的對比。
- 右圖所示為傳統(tǒng)方案:可以理解為將所有樣本排成一行,然后按照 Max Sequence Length 進(jìn)行切割,會導(dǎo)致大量 Sample 被截斷。
- 左圖為本文的方案:首先將所有 Document 按照 Max Sequence Length 截斷,然后使用 Best-Fit Decreasing 算法來進(jìn)行組合優(yōu)化(Bin Packing 優(yōu)化)。?
如下圖 Table 2 所示,使用本文的 Best-Fit Packing 可以保證只增加不到 0.003% Sequence 數(shù)量,對訓(xùn)練效率的影響也就微乎其微。
6.3 結(jié)果
如下圖 Table 3 所示,作者訓(xùn)練了 3 個不同規(guī)模、序列長度的模型:
如下圖 Table 4 所示,提出方案訓(xùn)練出的模型的閱讀理解能力有明顯提升:
如下圖 Table 9 所示,幻覺問題也可以明顯降低:
如下圖 Table 10 所示,作者也針對 Attention Mask 進(jìn)行了相關(guān)消融實驗,可以看出原始 Concatenation 方案加入 Block Diagonal Mask 后也有一定的提升。可以證明 Attention Mask 和截斷都會對效果有一定影響,不過 Packing(避免截斷) 的影響似乎更大一些。
七、IBM Packing with FlashAttention2
7.1 概述
在 [2407.09105] Enhancing Training Efficiency Using Packing with Flash Attention 中,作者總結(jié)了不同 Packing 策略、Mask 方式及與 FlashAttention 結(jié)合的優(yōu)勢。此外,作者也將相關(guān)工作提交到了 Huggingface Transformer 中,提供了新的 DataCollatorWithFlattening,具體可以參考:通過打包Flash Attention 來提升Hugging Face 訓(xùn)練效率。
7.2 相關(guān)方案
如下圖 Table 1 所示,作者分析了不同的 Packing 方案以及它們的影響,具體包含如下幾種方式:
- RandomSampling + Padding:最傳統(tǒng)的隨機采樣,然后 Padding 的方式。存在冗余計算,并且占比很高。
- GroupByLength+Padding:先排序,然后盡量保證每個 Batch 中的序列長度接近。可以減少 Padding 的占比。
- RandomSampling + PosID:隨機采樣,但是不 Padding,而是通過 PosID 支持變長序列。幾乎沒有冗余計算,但可能存在明顯的負(fù)載不均衡(計算量)。
- FixedLengthPacking:隨機采樣,隨機 Packing,并且最后一個Sample 可能截斷,保證填滿 Max Sequence Length。沒有區(qū)分不同 Sample,也就是 Causal Mask,沒有冗余計算,并且負(fù)載很均衡。
- FixedLengthPacking + PosID:相比FixedLengthPacking多了 PosID,也就是可以區(qū)分不同 Sample,對應(yīng) Block Diagonal Mask。但依然會存在末尾截斷,并且可能負(fù)載不均衡。
- MultiPack + PosID:使 Sequence 中的數(shù)據(jù)盡量接近 Batch 的 Max Sequence Length,降低 Sequence 中的長度不均衡,可以參考GitHub - imoneoi/multipack_sampler: Multipack distributed sampler for fast padding-free training of LLMs。需要對數(shù)據(jù)進(jìn)行排序。
- SortedPacking + PosID:通過排序,使同一個 Batch 中的計算復(fù)雜度盡量接近。可以盡可能降低計算負(fù)載不均衡問題。
- RandomPacking + PosID:與FixedLengthPacking + PosID相比主要的區(qū)別就是最后一個 Sample不截斷,可能存在部分 Bubble。?
7.3 結(jié)果
作者通過微調(diào)任務(wù)對比了一系列模型使用不同方案的效果和速度,其中 Max Sequence Length(msl)為 4096,每個 GPU 的 Mini-Batch Size 為 4。對應(yīng)配置如下:
- no:表示最原始的RandomSampling + Padding,可以作為基線,冗余計算比較多,速度最慢,但是效果有保障。
- yes:表示FixedLengthPacking,存在交叉污染。
- flat:表示訓(xùn)練前 Offline Packing 好,也就是引入了排序,并且使用 PosID 實現(xiàn) Document Level Mask。
- mini:表示訓(xùn)練中 mini batch 的 Online Packing,和 Random 類似,并且使用 PosID 實現(xiàn) Document Level Mask。
如下圖所示(PS:這里只是部分結(jié)果,全量請參考論文,結(jié)論基本一致,可以看出 yes 和 flat 都會對精度(VLoss)有比較大的影響,但速度(Tok/s)確實快了很多,可以達(dá)到 Baseline 的 3x-4x;而 mini 可以在保證精度的情況實現(xiàn)實現(xiàn) 2x 左右加速。
作者使用 FLAN_20k 數(shù)據(jù)集在 Mistral-7B 上針對之前提到的幾種方案做了更多實驗(PS:這里只是部分,gas 表示梯度累加次數(shù)),看起來 Packing 的方案都能獲得不錯的速度,但是精度的影響因素就比較多。但整體來說 RandomPacking + PosID 的方案還不錯,看著似乎也不能有太多的梯度累加。
八、參考鏈接
- ??https://arxiv.org/abs/2107.02027??
- ??https://arxiv.org/abs/2310.10638??
- ??https://arxiv.org/abs/2401.18058??
- ??https://xtuner.readthedocs.io/zh-cn/latest/acceleration/length_grouped_sampler.html#length-grouped-sampler??
- ??https://medium.com/@ChatGLM/glm-long-scaling-pre-trained-model-contexts-to-millions-caa3c48dea85??
- ??https://arxiv.org/abs/2404.10830??
- ??https://arxiv.org/abs/2407.09105??
- ??https://huggingface.co/blog/zh/packing-with-FA2??
- ??https://github.com/imoneoi/multipack_sampler??
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
