無比喻,不論文!用「畫家流水線」的方式理解Transformer中間層
盡管Transformer架構已經主宰了當今幾乎所有的大模型,但我們依舊對它的工作原理知之甚少。
而且,基于Transformer的預訓練LLM動輒有幾十億參數,很難直接對模型進行可解釋性分析。
同時,模型中間層由N個相同的塊堆疊在一起,它們之間唯一的區別只有層次位置和權重值,這就讓理解中間層更加困難。
然而,最近發表的一篇論文卻給出了一個十分通俗易懂的比喻——「畫家流水線」。
論文地址:https://arxiv.org/pdf/2407.09298v1
有著「東京AI夢之隊」之稱的Sakana AI,聯合IBM前AI負責人Satya Nitta創始的Emergence AI,兩個團隊的研究人員用一種新的「打開方式」來解釋Transformer架構的中間層。
值得一提的是,這篇論文作者之一Llion Jones同樣也是當年Transformer架構的共同創建者之一。
那么,「畫家流水線」這個比喻該如何理解呢?
首先,輸入被看作是一張畫布,輸入通過N個組成中間層的塊的過程,就像是畫布在「畫家流水線」上進行傳遞的過程。
有些畫家擅長畫鳥,而有些畫家則更擅長畫魚。每個畫家從前面的畫家手中接過畫布,然后決定是在畫上添幾筆,還是直接傳給后面的畫家。
在這個類比中,非常重要的一點是,每個畫家都使用相同的「詞匯」來理解畫作,因此一個畫家可以在流水線上從前一個畫家手中接過畫作,但不會因為對畫面理解不同而造成災難。
畫家們也可以重新排序(調整圖層的前后順序),甚至可以同時添加筆觸,就像N個塊可以并行運行。
這個類比并不是一個嚴謹的理論,但可以提供一個幫助我們思考Transformer層的有趣視角。
在這個類比的啟發下,研究人員提出了一些假設,并通過實驗來驗證這些假設是否成立——
- 不同層是否使用相同的表征空間?
- 所有的層都是有必要的嗎?
- 中間層是否都在執行相同的功能?
- 層的順序重要嗎?
- 我們能并行運行各層嗎?
- 順序是否對與某些特定任務而言更重要
- 循環是否有助于并行層?
- 哪些變體對性能的損害最小?
實驗
主要用于實驗包括兩種預訓練LLM,分別是decoder-only架構的Llama2-7B,以及encoder-only架構的BERT。Llama2-7B有70億個參數和32層(每層含2.02億個參數),BERT僅有24層和3.4億個參數。
在下述所有實驗過程中,模型都是凍結的。除了對BERT進行GLUE基準測試時進行了標準的微調步驟,參數沒有經過任何修改。
評估過程采用了ARC(科學考試題)、HellaSwag(常識)、GSM8K(數學應用題)、LAMBADA(單詞預測)等常用基準。
其中LAMBADA任務可以衡量模型困惑度(perplexity),任務最接近預訓練時的原始token預測。
結果發現,Transformer的中間層有一定程度的一致性,但不冗余,而且對數學、推理任務而言,各層的運行順序比在語義任務中有更重要的影響。
各層「說同一種語言」?
Transformer中的不同層是否共享相同的表示空間?
為了回答這個問題,論文采用的方法是讓模型跳過特定層或調換相鄰層的順序,觀察會不會出現災難性后果。
圖2中展示了Llama 2 7B在跳過或調換一些層后,模型整體在Open-LAMADA基準上的表現。
可以看到,除了起始和末端的幾層,模型對這兩種架構修改都表現出了相當強的魯棒性。
因此可以得出初步結論:1)中間層共享同一個表示空間,2)表示空間與「外層」(第一層和最后幾層)不同。
為了進一步驗證,論文還進入模型內部,測量了不同層中隱藏狀態內激活函數的余弦相似度(圖3),表明這種一致性在三個模型的所有中間層都成立。
上圖還可以很清晰看到,模型各層自然形成了4~5個不同的相似組,比如Llama 2 13B模型中分別是:第0層,1-3層、中間層,以及最后的1層或2層。
據此,Transformer中的所有層可以被大致分為三類:起始層、中間層和結束層。
此外,圖3中的矩陣也能和圖2中的模型分數相對應,更能有力證明,中間層之間共享語義表達空間。
所有層都必要?
為了進一步檢驗中間層的重定向空間是否真正共享(除了具有接近的余弦相似性),研究人員嘗試跳過多個層。
也就是說,將第N層的輸出直接送入第N+M層的輸入(其中M>1),從而「跳過」M-1層。
在不進行任何微調的情況下,這個實驗是要看看N+M層能否理解來自N層的激活,盡管它在訓練中只接受了來自N+M-1層的輸入。
結果顯示,Llama2-7B和BERT-Large的許多基準性能都出現了一定程度的下降。
那么,所有層都有必要嗎?這一問題已經有了答案。
No! 并非所有層都是必要的,至少有幾個中間層可以跳過,而不會發生災難性故障。
左圖:Llama2-7B跳過N層~32-N層的基準測試結果(歸一化);右圖:BERT跳過N層~24-N 層的基準測試結果(未歸一化)
中間層功能相同嗎?
如果中間層共享一個共同的表征空間,這是否意味著這些層是多余的呢?
為了驗證這一點,研究人員重新進行了上一小節的「跳過」實驗。
但不同的是,這次不是直接跳過M個中間層,而是用模型最中心的的一層代替全部M個層(Llama是第16層,BERT是第12層),相當于在這一層上循環T-2N+1次,其中T是層的總數。
結果表明,隨著被替換層數M的增加,基準測試結果迅速下降。
在研究人員所嘗試的所有測試中,這一項測試的變化是最嚴重的,比直接跳過一些層還要嚴重得多。
因此,中間層功能相同嗎?這一問題的答案是——
No! 在中間層之間共享權重是災難性的,這表明中間層在執行不同的功能。
用中心層替換M個中間層(左側經過歸一化,右側未經歸一化)
順序重要嗎?
之前的實驗表明,中間層共享一個表征空間,但對這個空間執行不同的操作。
那么另一個問題來了——這些操作的執行順序有多重要?
論文進行了兩組實驗來檢驗這個問題。首先,以與預訓練完全相反的順序運行中間層,如下圖所示:
第二組則是以隨機順序運行中間層,最終結果是取10個隨機種子進行實驗后的均值。
圖6和圖7分別展示了中間層完全翻轉和隨機順序的結果,雖然都出現了一定程度的性能下降,但兩者的結果都優于直接跳過的情況。
所以,中間層順序重要嗎?這一問題的答案是——
比較重要。改變中間層的執行順序,無論是隨機打亂或者完全翻轉,都會導致模型性能退化。
并行運行
如果層本身的存在比它們的執行順序更重要,那么我們是否可以獨立運行各層,最后合并它們的結果呢?
比如像下圖中,將原本堆疊在一起的中間層展開,并行運行后取各層輸出的平均值,傳遞給最后的N個層。
實驗結果顯示,GSM8K(數學應用題)基準中,模型性能有劇烈的變化,直線下降,其他基準分數的下滑則平緩得多。
我們暫且可以下這樣一個結論:并行運行是可行的,但解決數學問題除外。
要理解這種性能下降,可以用我們的「畫家流水線」進行類比:某些中間層只有在看到合適輸入時,才能對結果有所貢獻,就像一個擅長畫車輪的畫家,只有在畫面上看到汽車車身時,才更有可能畫出輪子。
如果是這種情況,將中間層并行運行的過程迭代多次應該會提高性能。
如下圖所示,論文將多個并行層的平均輸出再作為輸入反饋回去,如此進行一定次數的循環。
圖9顯示了循環3次的結果,與圖8中沒有循環的方案相比,性能曲線的確相對平緩,尤其是在圖右BERT模型未經歸一化的分數上更加明顯。
圖10更清楚直觀地展示了,并行的中間層數和循環次數如何影響性能,其中紅框圈出了每列上的最高值。
除了29層和31層(接近Llama 2 7B的總層數32)得出例外的結果,從5層到27層都呈現出一致的趨勢:最佳迭代次數大致與并行化層數呈線性比例。
實驗結果總結
將上述所有實驗結果放到同一張圖中(圖11),我們就能比較不同變體對模型性能的影響程度。
左圖(Llama2)取各基準的中值,右圖(BERT)取各基準的平均值
「隨機化層順序」和「循環并行」分別在Llama2和BERT-Large上造成了最少的性能下降,「中間重復」方案(用中心層運行多次代替整個中間層)則在兩個模型上都造成了最嚴重的滑坡。
討論
自從Transformer發布后,大多數工作都在關注架構的修改和優化,以達到性能提升或參數減少。這篇論文則提供了另一種視角,調查了層并行化和重用的影響。
基于「Transformer層即畫家」這個類比,我們開頭提出的幾個問題都通過實驗得到了答案,最后得到了3個有趣的發現:
- 所有Transformer層可以大致分為三類:起始層、中間層和結束層,其中中間層占比最大;
- 中間層具有一定程度的一致性,但并不冗余;
- 與語義任務相比,各層的執行順序對數學和推理任務更為重要。
為什么Transformer架構面對各種架構修改時能表現出如此強大的魯棒性?作者表示將在之后的工作中再深入研究。
一個可能的假設是,訓練過程中的殘差連接是各層共享相同表征的必要條件。
我們已經知道,殘差連接有助于解決梯度消失問題,然而相比沒有殘差連接的Transformer,加上殘差會降低性能。
如果能在沒有殘差的Transformer上重新運行上述架構的變體,看看是否會破壞完全無殘差模型所取得的微薄收益,那將會非常有趣。
對于未來的其他工作,研究人員還計劃「解凍」模型,并研究Transformer是否需要(以及需要多長時間)通過微調來適應上述的架構變化。
雖然本文的目的是更好地理解Transformer的中間層,而非引入新模型,但根據實驗結果,中間層并行或者干脆跳過都可以用適度的準確性損失換取更低的推理延遲。
作者團隊
本文作者分別來自兩家AI初創公司:Sakana AI和Emergence AI。
Sakana AI在今年年初剛剛獲得3000萬美元的種子輪融資,由Lux Capital領投,并得到了硅谷頂級風投公司Khosla Ventures以及Jeaf Dean、Alexandr Wang等大佬的支持。
公司研發的重點是基于自然啟發的新型基礎模型,創始團隊也是星光熠熠,一半成員來自「AI黃埔軍校」——谷歌大腦和DeepMind。
相比于關注基礎研究的Sakana,Emergence AI更關注應用,專門從事LLM驅動的multi-agent系統研發。
公司聯合創始Satya Nitta曾擔任IBM研究院「AI解決方案」領域的全球主管,其中的許多研究人員和工程師也同樣來自谷歌、Meta、微軟、亞馬遜和Allen AI等頂尖機構。
Emergence上個月剛剛從Learn Capital獲得9720萬美元的資金,以及額外的總計超過一億美元的信貸額度,未來的發展也是前途可期。