LLM 預訓練加速的新方法:八種模型增長方案總結
一、背景
LLM 的涌現能力依賴于其模型規模的增長,而 Scaling Law 也在推進 LLM 朝著越來越大的方向發展。然而,LLM 預訓練的成本非常高,尤其是其與模型規模、數據量成正比,一個千億參數量的模型往往需要幾千個 GPU 訓練幾個月的時間。加速 LLM 預訓練也因此稱為一個非常有前景的研究方向。
當前常見的優化方案為優化分布式策略,通信,以及訓練穩定性等。與此同時,很多時候大家都會訓練各種規模的 LLM,例如 LLaMA 系列模型。也有許多工作在嘗試利用已經訓練好的較小 LLM 的權重,然后通過漸進式學習的方法加快較大 LLM 模型的訓練,比如使用訓練好的 LLaMA3 8B 模型來初始化 LLaMA3 30B 模型并繼續訓練,這種技術可以稱為模型增長(Model Growth)。
本文中,我們將總結一系列模型增長的方案,以便對模型增長的發展以及各種方案有一個更清晰的認識。具體來說,我們將分別介紹 Net2Net,StackBert,bert2BERT,LiGO,LEMON、MSG 和 Mango 幾種方案。
PS:其實當前很多 MoE 模型的訓練正是屬于模型增長的范疇。比如說,Mixtral-8x7B( Mixtral of experts | Mistral AI | Frontier AI in your hands) 是從 Mistral-7B 初始化而來,專家直接拷貝了 Mistral-7B 中的 FFN。此外,也可以使用 Mistral-7B 針對不同任務微調 FFN 后再來作為專家。
相關工作可以參考:
- ???Meta 發布 BTX:模型融合與 MoE 的結合???
- ???MoE 系列論文解讀:Gshard、FastMoE、Tutel、MegaBlocks 等???
- ???7 種 LLM 融合擴展總結:LLaMA-Pro、SOLAR、MoE 等???
二、Net2Net
2.1 摘要
Net2Net([1511.05641] Net2Net: Accelerating Learning via Knowledge Transfer)是模型增長領域的開創性工作,作者是陳天琦和 GAN 的作者 Ian Goodfellow 等。作者提出了基于存在的小模型來初始化大模型,并繼續訓練的方案。其主要包含寬度(widthwise)擴展和深度(depthwise)擴展兩個方面。如下圖 Figure 1 所示為 Net2Net 與傳統模型訓練的差異:
- 傳統 Workflow:訓練小模型和大模型時分別設計模型,并從 0 開始訓練。
- Net2Net Workflow:訓練小模型從 0 開始,訓練大模型時用已經訓練好的小模型來初始化大模型,然后繼續訓練。?
Function Preserving Initialization(FPI):目標是給定一個源模型,用它初始化目標模型,能保證給定相同的輸入,目標模型和源模型有相同的輸出。
2.2 寬度擴展(Net2WiderNet)
寬度擴展,主要是指在模型的層內擴展,增加層的寬度。如下圖所示:
- 左圖為原始網絡,輸入包含 x[1] 和 x[2],輸出為 y,其包含兩個線性的隱藏單元 h[1] 和 h[2]。
- 右圖為擴展之后的網絡,新增一個隱藏單元 h[3],其對應的權重直接拷貝 h[2],這樣 h[3] 和 h[2] 就是相等的。同時 y 中新增一個 h[3] 的權重,同時 h[2] 和 h[3] 的權重都變為 f/2。最終 y = e*h[1] + f*h[2] = e*h[1] + f/2*h[2] + f/2*h[3]。?
2.3 深度擴展(Net2DeeperNet)
深度擴展,主要是指增加模型的層數。如下圖所示:其主要是在網絡中插入初始化的 Identity Mapping 層,其輸入輸出是相同的,保證插入后依然是等價的。
三、StackingBert
3.1 StackingBert-v1
在 Efficient Training of BERT by Progressively Stacking 中,作者提出了通過堆疊 Bert 模型的 Transformer Layer 來擴展模型規模的方法。Transformer 模型的 Decoder 和 Encoder 中不同層之間結構完全一樣,輸入和輸出 Shape 相同,這就為通過 Copy 來復制 Layer 提供了很大的遍歷。如下圖 Figure 3 所示,假設已經訓練好 L 層 Encoder 的 Bert 模型,直接通過拷貝即可擴展為 2L 層,具體來說,第 0 層和 第 L 層完全一樣,第 i 層和第 L+i 層完全一樣。然后繼續訓練即可以得到 2L 層的 Bert 模型:
3.2 StackingBert-v2(MSLT)
在 [2011.13635] Progressively Stacking 2.0: A Multi-stage Layerwise Training Method for BERT Training Speedup 中,作者進一步對 StackingBert-v1 進行了擴展。具體來說,將一個 N 層 Encoder 的模型分 K+1 次訓練,第一次訓練一個 N/k 層的 Bert 模型,然后每次擴展 N/k 層并且進行訓練。其中綠色為凍結的層,紅色為訓練的層。也就是每次擴展后只訓練新擴展的層,全部擴展完之后再解凍所有層繼續訓練:
四、bert2BERT
4.1 摘要
在上述 StackingBert 的兩個版本中,都是在深度上擴展模型規模。然而,實際上 Bert Large 相比 Bert Base 除了深度更深外,每一層也更寬,如下圖所示,其 hidden size 從 768 擴展到 1024,相應的 head 數也有所增加。
在 [2110.07143] bert2BERT: Towards Reusable Pretrained Language Models 中,作者提出 bert2BERT,其同樣是為了利用已有的小的預訓練模型來加快更大的預訓練模型的速度,降低訓練成本。作者測試發現,對于 Bert Base 和 GPT base,通過重用一半左右大小的模型,可以節約 45% 和 47% 的計算成本。
4.2 FPI 矩陣擴展
在模型寬度上,其最主要的就是矩陣乘法計算,大部分的模型參數都是一個權重矩陣,對應 h=W*x。因此,最常見的就是矩陣擴展,如下圖 Figure 3 所示,一般分為兩步:
- 第一步:輸入擴展,新增 x3,相應的參數矩陣要增加一列。
- 第二步:輸出擴展,新增 h3,相應的參數矩陣要增加一行。?
如下圖 Figure 4 所示為一個滿足 FPI 的模型寬度擴展示例,其分為三步:
- 第一步:輸入擴展,在x1, x2 的基礎上新增 x3,不過 x3 是直接 copy 的 x1。因此,直接在權重矩陣上擴展一列,并將第一列的權重和第三列均分。此時 h1 = o/2*x1+p*x2+o/2*x1 = o*x1+p*x2;h2 同理,因此可以保證 FPI。
- 第二步:輸出擴展,在h1, h2 的基礎上新增 h3,不過 h3 是直接 copy 的 h2。此時,需要在權重矩陣擴展一行,新的第三行直接復制第二行即可保證 h3=h2。
- 第三步:上述擴展了第一層的輸出,也就是擴展了第二層的輸入,要保持輸出結果不變,采用第一步的方法在第二層的權重矩陣擴展一列即可。?
4.3 AKI 矩陣擴展
為了加快訓練的收斂速度,作者進一步提出了 Advanced Knowledge Initialization(AKI),其不僅考慮當前層的參數,也考慮下一層的參數。這樣做是因為之前其它的工作中發現相鄰的 Transformer 層比較相似,因此綜合考慮兩個相鄰層并不會對模型效果有太大影響。當然,這也會打破 FPI 約束。
如下圖 Table 2 所示,作者通過實驗對比了不同方案的效果,首先 AKI 會比 FPI 收斂更快,計算資源更少。此外,作者進一步加上了兩階段預訓練,也就是 bert2BERT,訓練成本節約 45.2%,明顯優于 StackBERT 和 MSLT:
五、LiGO
5.1 摘要
在 [2303.00980] Learning to Grow Pretrained Models for Efficient Transformer Training 中,作者提出了通過線性投影來使用小模型初始化大模型的方案。具體來說,作者將線性變換分解為線性寬度增長算子和線性深度增長算子的組合,并進一步采用這些增長算子的 Kronecker 分解來嵌入結構知識。在語言模型和視覺 Transformer 模型上的實驗表明,采用這些線性增長算子(Linear Growth Operator,LiGO)可以節約高達 50% 的訓練成本,優于之前的方法。
5.2 方法
LiGO 的實現方式如下圖 Figure 1 所示,作者定義了一個線性映射函數 M,其可以將小模型的參數 ? 映射為大模型參數 ?new。由于直接學習 M 的代價非常非常高,因此將其分解為寬度增長算子 Rwidth 和深度增長算子 Ldepth 的組合。為了減少可訓練參數的數量并嵌入結構知識,進一步通過 Kronecker 積來分解 Rwidth 和 Ldepth,這樣每個增長算子都可以表示為較小矩陣的 Kronecker 積。
Kronecker 積:假設矩陣 A 的大小為 m×n,矩陣 B 的大小為 p×q。那么 A 和 B 的 Kronecker 積 A?B 的大小將是 (m?p)×(n?q)。它的元素由以下方式確定:
5.3 實驗和結果
其訓練過程分為 3 步:
- 為了訓練大模型,首先需要學習線性映射 M。作者通過 100 次的梯度迭代來優化 M,這個過程相對于整個訓練來說代價很小。
- 然后使用 M 將小模型的參數 ? 映射為大模型參數 ?new。
- 接著使用常規的訓練方式來訓練大模型。
作者在 Bert 模型上進行了相關實驗,如下圖 Figure 2 和 Table 1 所示,LiGO 相比從頭訓練可以節約 40.7% 的訓練成本。(PS:比較奇怪的是,作者測試 bert2BERT 的收益反而比 StackBERT 和 MSLT 更低,這與 bert2BERT 論文不太相符。)
六、LEMON
6.1 摘要
在 [2310.07999] LEMON: Lossless model expansion 中作者提出了 LosslEss MOdel expansioN(LEMON),對之前的模型擴展方案進行了增強,主要聚焦在無損擴展。其在 Vision Transformer 模型上可以減少 56.7% 的訓練代價,BERT 上可以減少 33.2%。如下圖 Table 1 所示為 LEMON 與其它方法的對比:
6.2 方法
作者提出的 LEMON 方法包含三個基本組件:
- 非均勻的通用無損寬度擴展。
- 針對 LayerNorm 的平均寬度擴展。
- 無損深度擴展。
6.2.1 無損寬度擴展
主要包含兩個部分,一個是 MLP 擴展,一個是 MHA 擴展:
- MLP 擴展:如下圖 3(a) 所示,其與之前寬度擴展的主要不同就是:之前的擴展中一般 a=β=1/2,其復制出來的神經元和原始神經元提取完全相同的表征,而在這里的 a≠β。
- MHA 擴展:如下圖 3(b) 所示,在 Transformer 模型中,模型變寬通常也意味著 MHA 中 Head 的增加,這里作者采用直接拷貝整個 Head 的方案。?
6.2.2 平均寬度擴展
如下圖 Figure 4 所示,LayerNorm 的平均寬度擴展就是對于新增的行直接使用之前行的平均,同時對 LayerNorm 中的 μ 增加縮放因子。以此可以保證最終的分布不變,滿足 FPI。
具體的證明如下圖所示,其要點是通過擴展均值可以保證擴展后均值不變,方差有一個固定的縮放因子,以此可以保證擴展后的輸出位置還為 0,擴展前的位置通過縮放因子也可以還原:
如下圖所示為 LayerNorm 擴展和 MHA 擴展的結合:
6.2.3 無損深度擴展
深度擴展基本上都是通過堆疊 Layer 的方式實現,要想保證堆疊后的無損,需要保證堆疊的層的輸入和輸出一樣,也就是等效于 Identify Layer。幸運的是,Transformer Layer 中的 MHA 和 MLP 層都有殘差連接,如下圖 Figure 2 中的紅框所示,因此只要保證新增層中 MHA 和 MLP 的輸出為 0 就可以保證無損。
具體的實現方式有兩種:
- 對應下圖 Figure 5(b) ,直接將最后一個全連接層置為全 0 即可。
- 對應下圖 Figure 5(c),同樣針對最后一個全連接層處理,將對應同一個神經元的權重設置為相反值,保證累積后和為 0。?
6.3 實驗&結果
如下圖 Figure 7 所示,作者與之前的模型擴展及模型蒸餾方案進行了對比,本文提出的 LEMON 在 ViT、Bert 模型上都收斂更快:
七、MSG
7.1 摘要
在 [2305.02869] Masked Structural Growth for 2x Faster Language Model Pre-training 中,智源(BAAI)等作者將漸進式增長分為兩個方面:
- 確定最優的增長規劃(Growth Schedule):主要是探討不同維度(比如深度、寬度)對增長效率的影響。
- 設計高效的增長算子(Growth Operator):當前的方法主要依賴新權重的初始化來繼承知識,并且很多只實現了非 FPI 的方案,從而限制了訓練的進一步提升。
為了解決以上問題,作者提出了 Masked Structural Growth(MSG),包括:
- 涉及所有可能維度的增長規劃。
- 與新權重初始化無關的嚴格滿足 FPI 的增長算子。
如下圖 Table 1 所示為 MSG 與其他方案的對比,可見 MSG 支持更多的維度,都滿足 FPI,可以獲得高達 2.2x 的加速:
7.2 增長算子
如下圖 Figure 1 所示,之前的各種方案(比如 Net2Net)都是通過特殊的初始化方式來保證盡量滿足 FPI。而本文的 MSG 的核心思路就是不管是什么樣的初始化,都通過 Mask 的方式讓新增的部分為 0,以滿足 FPI。同時,在訓練中逐漸增大 Mask,直到其 mask=1,此時就可以刪除 Mask。針對 MLP,LayerNorm,MHA 以及殘差連接都可以通過 Mask 方式實現。
7.3 增長規劃
對于 Transformer 模型,其決定模型規模的超參數主要有 4 個:hedden_dim, ffn_dim, head_num, layer_num。除了 layer_num 為深度擴展外,其它 3 個都是寬度擴展。在模型擴展過程中,可以一次擴展所有參數,也可以逐個擴展,但是又會存在擴展順序的問題,因此要找到一個最優的擴展方案是一個很有挑戰的工作。如下圖 Table 2 所示,作者針對 Bert 和 GPT-2 模型制定了不同的擴展規劃,并通過實驗驗證了各自的影響:
7.4 實驗
7.4.1 bert2BERT 對比
作者首先與 bert2BERT,以及從頭訓練進行了效果和速度的對比,可以看出,MSG 基本實現了效果和速度的最優,其最多可以實現 2.2x 加速:
7.4.2 LLM 預訓練對比
作者也進一步驗證了在 LLM 預訓練上的效果,具體來說,作者驗證了 LLM 從 16B 擴展到 51B 再擴展到 101B 的方案。其使用 24 DGX-A800(8x80G)機器,共192 A800,先在 16B 規模訓練 245.37B Token,然后在 51B 規模訓練 39.64B Token,最后在 101B 規模訓練 26.54B Token,總共訓練了 21.54 天,花費 100K 美元。如下圖 Figure 4 所示為其訓練的 Loss 曲線,作者與相似數據規模的 GLM-130B 模型對比,其只使用 10% 的訓練成本即可以達到 80% 的性能(具體可以參考作者的論文 [2309.03852] FLM-101B: An Open LLM and How to Train It with $100K Budget):
八、Mango
8.1 摘要
在 [2310.10699] Reusing Pretrained Models by Multi-linear Operators for Efficient Training 中,作者肯定了 bert2BERT 和 LiGO 中通過小模型來初始化大模型的方法,但也提出這些方法可能只考慮了局部相關性,而忽略了整個模型的相關性,這種部分映射的方法可能限制擴展模型的加速能力。因此,本文中,作者提出了一種將目標模型的每個權重與源模型的所有權重線性關聯的方案,并利用多線性算子(Multi-Linear Operator,Mango)來降低計算和空間復雜度。實驗結果表明,從 DeiT-small 到 DeiT-base,可以節省 76% 的計算成本,比 bert2BERT 和 LiGO 分別高出 12.0% 和 20.7%。
8.2 方案
本文的方案概覽如下圖 Figure 4 所示:
- 左圖:Transformer Layer 中的參數表示。
- 右上:一個 Transformer 模型中的全部參數可以用一個大的 Tensor M 表示,并由 B、I、O、L 這 4 個超參數決定。
- 右下:本文的方案,也就是 Mango 算子,可以學習一個線性映射函數 S,將小模型權重 M1 映射為大模型權重 M2。顯然,S 的空間極大,因此作者使用張量環矩陣乘法算子(Tensor Ring Matrix Product Operator,TR-MPO)來降低計算和空間復雜度,將其分為 4 個較小的張量,并通過 Rank 連接,通過訓練來學習到 4 個小的張量后就可以用于構建 M2:
- SB:表示在同一層中參數之間相互作用。
- SI和SO:分別表示輸入和輸出維度上的轉換。
- SL:表示層與層之間的關系。
- R:表示 S 的低秩級別。?
8.3 實驗和結果
如下圖 Table 1 所示,作者將 Mango 與之前的 bert2BERT 和 LiGO 的復雜度、能力進行了對比:
如下圖 Figure 7 所示,作者與 StackBERT、bert2BERT 以及 LiGO 的訓練效果進行了對比,可以看出,本文提出的 Mango 收斂更快,在 DeiT 上可以節約 76.4% 的成本:
九、參考鏈接
- ???https://mistral.ai/news/mixtral-of-experts/???
- ???https://arxiv.org/abs/1511.05641???
- ???https://proceedings.mlr.press/v97/gong19a.html???
- ???https://arxiv.org/abs/2011.13635???
- ???https://arxiv.org/abs/2110.07143???
- ???https://arxiv.org/abs/2303.00980???
- ???https://arxiv.org/abs/2310.07999???
- ???https://arxiv.org/abs/2305.02869???
- ???https://arxiv.org/abs/2309.03852???
- ???https://arxiv.org/abs/2310.10699????
