字節(jié)豆包大模型團隊突破殘差連接局限!預訓練收斂最快加速80%
自從 ResNet 提出后,殘差連接已成為深度學習模型的基礎(chǔ)組成部分。其主要作用是 —— 緩解梯度消失問題,使得網(wǎng)絡(luò)的訓練更加穩(wěn)定。
但是,現(xiàn)有殘差連接變體在梯度消失和表示崩潰之間存在一種 “蹺蹺板式” 的權(quán)衡,無法同時解決。
為此,字節(jié)豆包大模型 Foundation 團隊于近日提出超連接(Hyper-Connections),針對上述 “蹺蹺板式” 困境,實現(xiàn)了顯著提升。
該方法適用于大規(guī)模語言模型(LLMs)的預訓練,在面向 Dense 模型和 MoE 模型的實驗中,展示了顯著性能提升效果,使預訓練收斂速度最高可加速 80%。
研究團隊還發(fā)現(xiàn),超連接在兩個小型的視覺任務(wù)中表現(xiàn)同樣優(yōu)異,這表明,該方法在多個領(lǐng)域有廣泛的應(yīng)用前景。
- 論文標題:Hyper-Connections
- 論文鏈接:https://arxiv.org/pdf/2409.19606
1. 超連接的核心思想
前文提及,殘差連接的兩種主要變體 Pre-Norm 和 Post-Norm 各自都有其局限性,具體體現(xiàn)如下:
- Pre-Norm:在每個殘差塊之前進行歸一化操作,可有效減少梯度消失問題。然而,Pre-Norm 在較深網(wǎng)絡(luò)中容易導致表示崩潰,即深層隱藏表示過于相似,從而削弱了模型學習能力。
- Post-Norm:在殘差塊之后進行歸一化操作,有助于減少表示崩潰問題,但也重新引入梯度消失問題。在 LLM 中,通常不會采用此方法。
超連接的核心思路在于 —— 引入可學習的深度連接(Depth-connections)和寬度連接(Width-connections)。
從理論上,這使得模型不僅能夠動態(tài)調(diào)整不同層之間的連接強度,甚至能重新排列網(wǎng)絡(luò)層次結(jié)構(gòu),彌補了殘差連接在梯度消失和表示崩潰(Representation Collapse)之間的權(quán)衡困境。
深度連接與寬度連接
起初,該方法會將網(wǎng)絡(luò)輸入擴展為 n 個隱向量(n 稱作 Expansion rate)。之后每一層的輸入都會是 n 個隱向量,超連接會對這些隱向量建立以下兩類連接:
- 深度連接(Depth-Connections):這些連接類似于殘差連接,只為輸入與輸出之間的連接分配權(quán)重,允許網(wǎng)絡(luò)學習不同層之間的連接強度。
- 寬度連接(Width-Connections):這些連接使得每一層多個隱藏向量之間可進行信息交換,從而提高模型表示能力。
靜態(tài)與動態(tài)超連接
超連接可以是靜態(tài)的,也可以是動態(tài)的。
其中,靜態(tài)超連接(Static Hyper-Connections, SHC)意味著連接權(quán)重在訓練結(jié)束后固定不變。而動態(tài)超連接(Dynamic Hyper-Connections, DHC)則對應(yīng)連接權(quán)重可根據(jù)輸入動態(tài)調(diào)整。實驗表明,動態(tài)超連接效果更好。
2. 技術(shù)細節(jié)
超連接(Hyper-connections)
首先,考慮第 k 層的輸入隱藏向量,網(wǎng)絡(luò)的初始輸入為
,并將其復制 n 次,形成初始的超隱藏矩陣(Hyper Hidden Matrix):
這里,n 稱為擴展率(Expansion Rate)。在第 k 層,輸入是上一層的超隱藏矩陣,即:
對最后一層的超隱藏矩陣逐行求和,得到所需的隱藏向量,并通過一個投影層輸出網(wǎng)絡(luò)最終的結(jié)果(在 Transformer 中即為歸一化層和解嵌入層)。
為了簡化后續(xù)分析的符號表示,作者省略層索引,直接將超隱藏矩陣表示為:
超連接可以用一個矩陣來表示,對于擴展率為 n 的情況,超連接矩陣 HC 如下:
考慮一層網(wǎng)絡(luò),它可能是 Transformer 中的 attention 層或者是 FFN 層。超連接的輸出
可以簡單地表示為:
也就是說,用 作為權(quán)重對輸入
進行加權(quán)求和,得到當前層的輸入
:
同時,用于將
映射到殘差超隱藏矩陣
,表示如下:
最終的輸出表達式為:
偽代碼如下:
動態(tài)超連接的實現(xiàn)
超連接矩陣 的元素可以動態(tài)依賴于輸入
,動態(tài)超連接的矩陣表示為:
同樣,給定層 和輸入
,可以得到動態(tài)超連接的輸出:
在實際操作中,團隊結(jié)合了靜態(tài)和動態(tài)矩陣來實現(xiàn)動態(tài)超連接,動態(tài)參數(shù)通過線性變換獲得。
為了穩(wěn)定訓練過程,團隊在線性變換前引入歸一化,并在其后應(yīng)用 tanh 激活函數(shù),通過一個可學習的小因子進行縮放。動態(tài)參數(shù)的計算公式如下:
實驗表明,動態(tài)超連接在語言建模任務(wù)中優(yōu)于靜態(tài)超連接。
3. 為什么使用超連接(Hyper-Connections)
研究團隊認為,殘差連接的兩種變體,即前歸一化(Pre-Norm)和后歸一化(Post-Norm),可以被視為不可訓練的超連接。
隨后,團隊引入了順序 - 并行二象性概念,展示了超連接如何動態(tài)優(yōu)化層的排列以提升網(wǎng)絡(luò)性能。
殘差連接是不可訓練的超連接
前歸一化和后歸一化的殘差連接可以表示為以下擴展率為 的超連接矩陣:
其中,和
分別表示神經(jīng)網(wǎng)絡(luò)層輸入和輸出的標準差,
表示它們之間的協(xié)方差。
對于 Pre-Norm,其超連接矩陣是一個 的矩陣,右下三角部分填充為 1,其余部分為占位符 0。對于 Post-Norm,權(quán)重依賴于輸入和輸出的方差及協(xié)方差,形成一個
的矩陣。因此,它們的超連接矩陣是不可訓練的。
而本工作提出的方法的超連接矩陣是 矩陣,且權(quán)重是可訓練的,甚至可以基于輸入進行動態(tài)預測。
順序 - 并行二象性
給定一系列神經(jīng)網(wǎng)絡(luò)模塊,我們可以將它們順序排列或并行排列。作者認為,超連接可以學習如何將這些層重新排列,形成順序和并行配置的混合。
在不失一般性的情況下,可以將擴展率設(shè)置為 n=2。如果超連接以如下矩陣形式學習,神經(jīng)網(wǎng)絡(luò)將被順序排列:
在這種情況下,深度連接退化為殘差連接,如圖 (a) 所示。
當奇數(shù)層和偶數(shù)層的超連接矩陣分別定義為以下形式時,神經(jīng)網(wǎng)絡(luò)每兩層將被并行排列,類似于 Transformer 中的 parallel transformer block 的排列方式,如圖 (b) 所示。
因此,通過學習不同形式的超連接矩陣,網(wǎng)絡(luò)層的排列可以超越傳統(tǒng)的順序和并行配置,形成軟混合甚至動態(tài)排列。對于靜態(tài)超連接,網(wǎng)絡(luò)中的層排列在訓練后保持固定;而對于動態(tài)超連接,排列可以根據(jù)每個輸入動態(tài)調(diào)整。
4. 實驗結(jié)果
實驗主要集中在大規(guī)模語言模型的預訓練上,涵蓋了 Dense 模型和 MoE 模型。
實驗結(jié)果表明,使用超連接的模型顯著優(yōu)于使用殘差連接的模型。
1B Dense 模型實驗
只要擴展率 > 1,效果就十分顯著,且訓練更穩(wěn)定,消掉了訓練 loss 的 spikes。
7B Dense 模型實驗
團隊甚至 Scale 到了 7B 模型,效果也十分亮眼,同時可以看到有超連接的網(wǎng)絡(luò)訓練更穩(wěn)定。
7B 候選激活 1.3B 的 MoE 模型實驗
可以看到,下游指標全漲,在 ARC-Challenge 上甚至漲了 6 個百分點。
綜上,研究團隊介紹了超連接(Hyper-Connections),它解決了殘差連接在梯度消失和表示崩潰之間的權(quán)衡問題。實驗結(jié)果表明,超連接在大規(guī)模語言模型的預訓練以及視覺任務(wù)中都表現(xiàn)出顯著的性能提升。
值得注意的是,超連接的引入幾乎不增加額外的計算開銷或參數(shù)量,團隊認為,該成果具有廣泛的應(yīng)用潛力,可以推廣到文音視圖模態(tài)的不同任務(wù)上,包括多模態(tài)理解、生成基座模型等。
5. 寫在最后
團隊關(guān)注底層問題,尤其在 LLMs 和多模態(tài)方面,期望實現(xiàn)更多突破。
更多團隊技術(shù)研究進展,可以進入「豆包大模型團隊」技術(shù)解讀欄目了解。