簡化版Transformer來了,網友:年度論文
Transformer 架構可以說是近期深度學習領域許多成功案例背后的主力軍。構建深度 Transformer 架構的一種簡單方法是將多個相同的 Transformer 「塊」(block)依次堆疊起來,但每個「塊」都比較復雜,由許多不同的組件組成,需要以特定的排列組合才能實現良好的性能。
自從 2017 年 Transformer 架構誕生以來,研究者們基于其推出了大量衍生研究,但幾乎沒有改動過 Transformer 「塊」。
那么問題來了,標準 Transformer 塊是否可以簡化?
在最近的一篇論文中,來自 ETH Zurich 的研究者討論了如何在不影響收斂特性和下游任務性能的情況下簡化 LLM 所必需的標準 Transformer 塊。基于信號傳播理論和經驗證據,他們發現可以移除一些部分,比如殘差連接、歸一化層(LayerNorm)、投影和值參數以及 MLP 序列化子塊(有利于并行布局),以簡化類似 GPT 的解碼器架構以及編碼器式 BERT 模型。
對于每個涉及的組件,研究者都探討了是否可以在不降低訓練速度的情況下將其移除(包括每次更新步驟和運行時間),以及為此需要 Transformer 塊進行哪些架構修改。
論文鏈接:https://arxiv.org/pdf/2311.01906.pdf
Lightning AI 創始人、機器學習研究者 Sebastian Raschka 將這項研究稱為自己的「年度最愛論文之一」:
但也有研究者質疑:「這很難評,除非我看過完整的訓練過程。如果沒有歸一化層,也沒有殘差連接,如何能在大于 1 億參數的網絡中進行擴展?」
Sebastian Raschka 表示贊同:「是的,他們試驗的架構相對較小,這是否能推廣到數十億參數的 Transformer 上還有待觀察?!沟匀槐硎具@項工作令人印象深刻,并認為成功移除殘差連接是完全合理的(考慮到其初始化方案)。
對此,圖靈獎得主 Yann LeCun 的評價是:「我們僅僅觸及了深度學習架構領域的皮毛。這是一個高維空間,因此體積幾乎完全包含在表面中,但我們只觸及了表面的一小部分。」
為什么需要簡化 Transformer 塊?
研究者表示,在不影響訓練速度的前提下簡化 Transformer 塊是一個有趣的研究問題。
首先,現代神經網絡架構設計復雜,包含許多組件,而這些不同組件在神經網絡訓練動態中所扮演的角色,以及它們之間如何相互作用,人們對此尚不清楚。這個問題事關深度學習理論與實踐之間存在的差距,因此非常重要。
信號傳播理論(Signal propagation)已被證明具有影響力,因為它能夠激勵深度神經網絡架構中的實際設計選擇。信號傳播研究了初始化時神經網絡中幾何信息的演化,通過跨輸入的分層表征的內積來捕捉,在訓練深度神經網絡方面取得了許多令人印象深刻的成果。
然而,目前該理論只考慮初始化時的模型,而且往往只考慮初始前向傳遞,因此無法揭示深度神經網絡訓練動態的許多復雜問題,例如殘差連接對訓練速度的助益。雖然信號傳播對修改動機至關重要,但研究者表示,他們不能僅從理論上就得出簡化的 Transformer 模塊,還要依靠經驗見解。
在實際應用方面,考慮到目前訓練和部署大型 Transformer 模型的高昂成本,Transformer 架構的訓練和推理流水線的任何效率提升都代表著巨大的潛在節約意義。如果能夠通過移除非必要組件來簡化 Transformer 模塊,既能減少參數數量,又能提高模型的吞吐量。
這篇論文也提到,移除殘差連接、值參數、投影參數和序列化子塊之后,可以同時做到在訓練速度和下游任務性能方面與標準 Transformer 相匹配。最終,研究者將參數量減少了 16%,并觀察到訓練和推理時間的吞吐量增加了 16%。
如何簡化 Transformer 塊?
研究者結合信號傳播理論和經驗觀察,介紹了如何從 Pre-LN 模塊出發,生成最簡單的 Transformer 塊(如下圖)。
在論文第四章的每一個小節,作者分別介紹了如何在不影響訓練速度的情況下每次刪除一個塊組件。
這一部分的所有實驗都在 CodeParrot 數據集上使用了一個 18-block 768-width 的因果僅解碼器類 GPT 模型,這個數據集足夠大,因此當作者處于單個訓練 epoch 模式時,泛化差距非常?。ㄒ妶D 2),這使得他們可以專注于訓練速度。
刪除殘差連接
研究者首先考慮刪除注意力子塊中的殘差連接。在公式(1)的符號中,這相當于將 α_SA 固定為 0。簡單地移除注意力殘差連接會導致信號退化,即秩崩潰(rank collapse),從而導致可訓練性差。在論文 4.1 部分,研究者詳細解釋了他們的方法。
刪除投影 / 值參數
從圖 3 中可以得出結論,完全移除值和投影參數 W^V、W^P 是可能的,而且每次更新的訓練速度損失最小。也就是說,當 β_V = β_P = 0 和 identity 初始化的
時,在相同的訓練步數后,本研究基本上能達到 Pre-LN 塊的性能。在這種情況下,在整個訓練過程中都有 W^V = W^P = I,即值和投影參數是一致的。作者在 4.2 節介紹了詳細方法。
刪除 MLP 子塊殘差連接
與上述幾個模塊相比,刪除 MLP 子塊殘差連接要更具挑戰性。與之前的研究一樣,作者發現,在使用 Adam 時,如果沒有 MLP 殘差連接,通過信號傳播使激活更加線性仍會導致每次更新訓練速度的顯著下降,如圖 22 所示。
他們還嘗試了 Looks Linear 初始化的各種變體,包括高斯權重、正交權重或恒等權重,但都無濟于事。因此,他們在整個工作中使用標準激活(例如 ReLU)和 MLP 子塊中的初始化。
作者轉向并行 MHA 和 MLP 子塊的概念,這在幾個近期的大型 transformer 模型中已被證明很受歡迎,例如 PALM 和 ViT-22B。并行 transformer 塊如下圖所示。
作者在論文 4.3 節詳細介紹了移除 MLP 子塊殘差連接的具體操作。
刪除歸一化層
最后一個被刪除的是歸一化層,這樣就得到了圖 1 右上角的最簡塊。從信號傳播初始化的角度來看,作者可以在本節簡化的任何階段移除歸一化層。他們的想法是,Pre-LN 塊中的歸一化會隱式地降低殘差分支的權重,而這種有利的效果可以通過另一種機制在沒有歸一化層的情況下復制:要么在使用殘差連接時明確降低殘差分支的權重,要么將注意力矩陣偏向 identity / 將 MLP 非線性轉化為「更」線性。
由于作者在修改過程中考慮到了這些機制(如降低 MLP β_FF 和 Shaped Attention 的權重),因此無需進行歸一化處理。作者在第 4.4 節介紹了更多信息。
實驗結果
深度擴展
鑒于信號傳播理論通常關注很大的深度,而這種情況下通常會出現信號退化。因此一個很自然的問題就是,本文的簡化 transformer 塊所提高的訓練速度是否也能擴展到更大的深度?
從圖 6 中可以觀察到,將深度從 18 個塊擴展到 72 個塊后,本研究的模型和 Pre-LN transformer 的性能都得到了提高,這表明本研究中的簡化模型不僅訓練速度更快,而且還能利用更大的深度所提供的額外能力。事實上,在使用歸一化時,本研究中的簡化塊和 Pre-LN 的每次更新軌跡在不同深度下幾乎沒有區別。
BERT
接下來,作者展示了他們的簡化塊性能除了適用于自回歸解碼器之外,還適用于不同的數據集和架構,以及下游任務。他們選擇了雙向僅編碼器 BERT 模型的流行設置,用于掩蔽語言建模,并采用下游 GLUE 基準。
如圖 7 所示,在 24 小時運行時內,與(Crammed)Pre-LN 基線相比,本研究的簡化塊可以媲美掩蔽語言建模任務的預訓練速度。另一方面,在不修改值和投影的情況下刪除殘差連接再次導致訓練速度的顯著下降。在圖 24 中,作者提供了 microbatch 步驟的等效圖。
此外,在表 1 中,研究者發現他們的方法在 GLUE 基準上經過微調后,性能與 Crammed BERT 基準相當。
他們在表 2 中對下游任務進行了細分。為了進行公平比較,他們使用了與 Geiping & Goldstein (2023) 相同的微調協議(5 個 epoch、各任務超參數恒定、dropout regularisation)。
效率提升
在表 1 中,研究者還詳細列出了使用不同 Transformer 塊的模型在掩蔽語言建模任務中的參數數量和訓練速度。他們以預訓練 24 小時內所采取的 microbatch 步驟數與基線 Pre-LN Crammed BERT 的比率計算了速度。結論是,模型使用的參數減少了 16%,SAS-P 和 SAS 的每次迭代速度分別比 Pre-LN 塊快 16% 和 9%。
可以注意到,在這里的實現中,并行塊只比 Pre-LN 塊快 5%,而 Chowdhery et al.(2022 )觀察到的訓練速度則快 15%,這表明通過更優化的實現,整個訓練速度有可能進一步提高。與 Geiping & Goldstein(2023 年)一樣,此處實現也使用了 PyTorch 中的自動算子融合技術 (Sarofeen et al., 2022)。
更長的訓練
最后,考慮到當前在更多數據上長時間訓練較小模型的趨勢,研究者討論了簡化塊在長時間訓練后是否仍能達到 Pre-LN 塊的訓練速度。為此,他們在 CodeParrot 上使用圖 5 中的模型,并使用 3 倍 token 進行訓練。準確地說,是在批大小為 128、序列長度為 128 的情況下進行了約 120K 步(而不是 40K 步)的訓練,這將導致約 2B 個 token。
從圖 8 可以看出,當使用更多的 token 進行訓練時,簡化的 SAS 和 SAS-P 代碼塊的訓練速度仍然與 PreLN 代碼塊相當,甚至優于 PreLN 代碼塊。
更多研究細節,可參考原論文。