405B大模型也能線性化!斯坦福MIT最新研究,0.2%訓練量讓線性注意力提分20+
生產級大模型應用線性注意力的方法,來了。
線性Attention(包括RNN系列),再也不用困在幾B參數的范圍內娛樂了。
一套方法,即可線性化現有各種量級的Transformer模型,上至Llama 3.1 405B,也只需要十來張顯卡在兩天內搞定!
這就是斯坦福、MIT等科研機構推出的低秩線性轉換LoLCATs(Low-rank Linear Conversion with Attention Transfer)。
論文與代碼:https://github.com/HazyResearch/lolcats
應用LoLCATs,可以實現傳統注意力(softmax)到線性注意力的無縫轉移,
且轉換后僅需開銷很低的微調(LoRA),0.2%的參數更新即可恢復精度,對比同類的線性注意力模型或方法, 5-shot MMLU直接提高了20分左右!
也就是說,在幾乎不損失Transformer大模型語言能力的基礎上,將LLM的計算復雜度從二次方降到了線性。
線性Attention一事,前人之述備矣,然則,能夠真正做大做強,還是第一次。
尤其具有實用價值的是,LoLCATs實現了極小的開銷和接近原始模型的性能。
LoLCATs的線性化轉換只需兩個步驟:
首先使用線性Attention的形式替換原始Attention部分,并利用簡單的MSE損失訓練新增的參數,以近似softmax注意力;
然后通過低成本的微調(LoRA)來進一步提高模型的精度。
為了實現可擴展性,作者采用更精細的「block by block」訓練,將LLM的每k層看成一個block,盡在塊內聯合訓練注意力,以提高分層注意力匹配。
就如上圖所表示的那樣,一個羊駝(Llama)可以看成多個小刺猬疊在一起,每個小刺猬擁有獨特的用于線性化的參數,并且相互之間可以獨立訓練。
LoLCATS 加速 LLM
為了避免昂貴的訓練成本,研究者們一直在不斷探索兩個方面:
make models fast 與 create fast models
諸如Mamba、RWKV、TransNormer、Hawk、 Griffin和 StripedHyena等高效的subquadratic models不斷出現,
而關于將流行的LLM線性化的工作也讓我們眼前一亮。
但是線性化LLM往往伴隨著模型質量的顯著降低,你甚至能通過MMLU的測試分數猜出一個模型是不是傳統的Attention架構,或者傳統Attention塊在模型中的占比。
另外,從實用的角度講,只有拿下了生產級別的大模型,線性化的道路才能真正與傳統Transformer平分秋色。
預備知識
先打基礎:為什么要線性化?
正常的softmax注意力可以表示為下圖上面的公式:
由于softmax的緣故,只能先算Q乘K,導致中間緩存和計算量隨序列長度的平方增長;
線性化就是設計倆函數來近似softmax,從而把公式轉化成下面的形式。
此時Q和K不需要綁在一起了,就可以先算K乘V,這個順序的改變導致中間緩存和計算量隨向量長度的平方增長,而相對于序列長度是線性關系。
這就是線性化的意思,這樣的Attention也就不懼怕長序列帶來的壓力了。
開始線性化
本文中,作者的主要想法是向線性化Transformer中添加三個簡單的概念:
1. Learnable (Linear) Attentions:可學習的(線性)注意力
2. Low-rank Adaptation:低秩適配
3. Layer-wise Optimization:分層優化
Learnable Attentions
首先訓練線性注意力來模擬和替換softmax注意力。這種「注意力轉移」的靈感來自作者之前的一篇工作:Hedgehog。
論文地址:https://arxiv.org/pdf/2402.04347
如何設計設計精妙復雜的函數來近似softmax注意力?
作者表示:與其讓人類煞費苦心,不如交給AI自己去學!
相比于Hedgehog中只使用可學習的線性注意力,作者在LoLCATs中,將其推廣為可學習的線性注意力和 + 滑動窗口。
研究人員將線性和softmax注意力統一在一個層中,訓練一些新增的參數以從整體上近似softmax注意力。
對于N個token的序列,前W個token用于計算softmax注意力,后N-W個token用于計算線性注意力,然后將這些值組合。
在Hedgehog中,作者通過KL散度來訓練特征圖以匹配注意力權重,而本文改為在注意力層的輸出上使用MSE 失。
這繞過了Hedgehog的一個限制:需要將所有注意力權重實例化為監督目標。
相反,LoLCATs可以使用FlashAttention來計算softmax注意力輸出,并將線性化注意力的內存消耗保持在O(N)。
只需將這些特征圖插入到每個現有的注意力中,即可創建線性化的 LLM。凍結所有其他權重,只訓練這些特征圖,對于7B的LLM來說,只需要調整0.2%的參數。
Low-rank Adaptation
之前的線性化工作,通常需要一個比較昂貴的端到端訓練階段。
但在LoLCATs這里,可以通過簡單地將低秩適應(LoRA)應用于注意力的QKVO權重來恢復模型的性能。
凍結所有其他內容,只訓練LoRA權重,在某些自然語言數據上,最大限度地減少LLM輸出的next-token預測損失。
Layer-wise Optimization
大多數情況下,只需要以上兩步就搞定了。但對于像Llama 3.1 405B這種規模的模型來說,還需要努力一下。
通過簡單地聯合優化所有層,可以成功地線性化7B到70B參數范圍的LLM,但整體訓練時,后面層的MSE會比前面的層更大。
當模型變得更大更深時,MSE升級為了微調Llama 3.1 405B的真正問題。
為此,研究人員使用了更精細的逐塊訓練,將Llama 3.1 405B分成多個k層塊,并僅在每個塊內聯合訓練注意力。
當使用一些線性化數據并行訓練所有模塊時,只需為每個塊預先計算LLM的隱藏狀態。
可以調節k來平衡并行訓練的速度與預計算的內存,并將隱藏狀態保存到磁盤。不需要花哨的成本模型,對于50M token的線性化來說:
k = 1時,需要2字節 × 126層 × 50M token × 16384(hidden size)= 200TB的磁盤空間來存儲隱藏狀態。
而k = 9時,磁盤空間的需求將減少為22TB,這時仍然能在單個GPU上并行訓練每個塊(9層)。
——后者顯然更友好一點,所以作者將Llama 3.1 405B的126層拆分為14個9層塊,在14個GPU上并行進行注意力的線性化,過程僅需5個小時。然后用LoRA將它們全部拼接在一起,就得到了最終模型。
實驗結果
質量恢復
下表給出了6個流行的LLM評估任務的結果。
與最近的一些線性化方法相比,LoLCATs顯著提高了不同任務和不同LLM的質量和訓練效率。
盡管只訓練了0.2% 的模型參數(40M token),LoLCATs將線性化與原始模型的性能差距平均縮小了80%以上,token to model的效率提高了500~2500倍。
在7B這個量級上,LoLCATs優于所有的線性注意力(包括RNN系列)模型:Mamba、RWKV、TransNormer、Hawk、 Griffin和 StripedHyena。
挑戰405B大模型
最后,作者使用LoLCATs將線性化擴展到Llama 3.1 70B和更大的405B模型。
與之前的線性化方法相比,首先是質量上的顯著改進。通過控制相同的線性 + 滑動窗口層,對于Llama 3.1 70B,在5-shot MMLU上的精度實現了39點的提升,對于Llama 3.1 405B,同樣實現了38.3分的改進。
其次是訓練效率的提高,在單個8x80GB H100上線性化Llama 3.1 70B僅需18個小時,而線性化Llama 3.1 405B所花費的時間比之前用于8B模型的方法還要少。
參考資料: