不是RNN的鍋!清華團隊深入分析長上下文建模中的狀態崩潰,Mamba作者點贊
與Transformer相比,RNN模型的一大優勢是應對長序列的能力。
比如Mamba,內部狀態大小始終保持不變,計算隨序列長度線性增長,吃得多,消化快。
理論雖如此,但實際情況卻是,目前的這些RNN模型在長上下文中的有效性并不能令人滿意。
為啥會這樣?空有效率但實際上能力不行?
近日,來自清華的研究團隊對此進行了深入的實驗研究:
論文地址:https://arxiv.org/pdf/2410.07145v1
文章表明,Mamba這類RNN模型在長上下文中主要面臨兩個問題:
一是無法推斷比訓練長度更長的輸入,原因是較短的訓練數據導致了循環狀態過擬合;
二是內存容量的上限,由于模型無法有效遺忘很久以前的信息,導致新的信息存不進來了。
——這倆問題明顯不是RNN的鍋。
而經過研究人員的對癥下藥,Mamba-2(370M)在256K上下文長度上達到了近乎完美的密鑰檢索精度。
所以結論就是,Mamba yes!「RNN神教」前景一片光明!
對此,Mamba的作者Albert Gu點贊轉發,并發表了相當詳細的見解:
「這是一篇很棒的論文(名字也很棒)—— 關于狀態空間模型(SSM)的狀態容量和長上下文能力的巧妙實驗。」
令人驚訝的是,對于每個狀態大小 M,當訓練上下文長度達到或超過某個臨界值 K 時,都會出現一個轉折點,在這個點上 SSM 就能夠穩健地實現長度泛化。
這是因為當上下文長度小于 K 時,循環狀態沒有被充分利用,導致模型在訓練期間會「過擬合」。但一旦通過足夠長序列的訓練使模型的狀態容量得到充分利用,它就會自動獲得泛化能力。
值得注意的是,K 與 M 竟然呈線性關系!—— 這表明每個 token 可能存在某種固有的信息含量(即存在一個值 B,使得上下文中的每個 token 對應 B 字節的循環狀態)。這個 B 值可能是由模型架構決定的?
「反過來說,過分擔心循環模型的長度泛化問題可能是一個誤區。我們無需設計新機制或特殊的緩解措施:只需要在更長的序列上訓練(因為是線性時間復雜度,所以不會增加計算開銷!),就能獲得更好的泛化效果。」
最后,Albert Gu用一句話總結:要讓你的Mamba吃得飽飽的,它就能發揮出最佳狀態!
喂飽你的Mamba
先來復習一下基礎知識。
本文以Mamba2作為主要研究對象,內部的計算表示為下圖中的并行結構:
整體的輸入輸出遵循SSM(也即RNN)的形式:
而把上圖中模塊內部所有的計算寫出來,就是下面這一坨公式:
之前提到的兩個問題,核心在于模型的內部狀態,也就是ht的表現。
所以下面在探索問題和解決方案時,咱們可以重點關注這些公式中,與ht計算相關的參數。
之前有研究表明,當上下文長度超過其訓練長度時,Mamba-1和RWKV-4的性能會嚴重下降。
順著這個思路,研究人員在兩個方向上進行了實驗分析:狀態崩潰(STATE COLLAPSE)和容量上限(STATE CAPACITY)。
狀態崩潰
狀態崩潰(SC)指的是,RNN模型在輸入上表現出異常行為的時間比訓練期間看到的時間更長的現象。
上圖展示了Mamba-2和RWKV-6在訓練長度之外的語言建模損失。為了可控性和合成任意長度的提示,這個損失是在僅由「\n」字符組成的提示上計算的(稱為「newlines」提示)。
結果表明,當上下文長度遠大于其訓練長度時,兩個RNN的性能都會嚴重下降,最后就跟瞎猜差不多了。
語言建模可能無法反映下游能力,上圖給出了Mamba-2(在8K上下文窗口上訓練)在密鑰檢索任務上的評估結果。
我們可以發現,Mamba-2在8K上下文中具有近乎完美的檢索準確性,但在序列長度超過16K后就沒法看了,無論模型參數量大小。
從上面的公式來看,這種結果可能出人意料,因為內部狀態ht的更新應該具有穩定的指數內存衰減,即對于最后k個token具有良好的檢索準確性。
問題出在哪里?
由于遞歸狀態的維度不會隨時間而變化,因此狀態崩潰期間行為的急劇變化一定是狀態值變化的結果。
作者對Mamba-2 370M中每一層的遞歸狀態進行了統計,發現當上下文長度超過訓練長度時,一些頭部的平均值和方差會急劇變化:
圖5顯示了模型第38層第2個頭的狀態,在t=20K時方差爆炸。從中可以發現這種方差爆炸在很大程度上可以歸因于少數異常通道,其余大多數通道則相對穩定。
分析一下公式,與ht計算有關的?t、Bt和xt:
如上圖所示,雖然三者都是輸入的函數,但xt相對穩定,而Bt比?t更早發生爆炸,進一步探索還能發現生成?t和Bt的卷積權重明顯更大。
作者認為,產生SC的原因是,對于訓練長度來說,狀態容量過大,模型能夠實現強大的語言建模性能,而無需學習如何忘記。
上圖顯示了第一個token在不同時間步的內存強度,作者發現爆炸的頭(第38層的第2、4、7個頭)強烈傾向于在訓練長度內保留所有信息,在t=8K時內存強度超過0.8。
解決方案
為了緩解SC,使模型沿序列長度更好地泛化,作者提出了3種解決方案,總的思想是修改狀態的update規則來避免其溢出。
Method 1: Forget More and Remember Less
通過增加狀態衰減量(忘記更多)或減少輸入信息的數量(記住更少)來減少SC,作者選擇干預Bt和αt(分別控制輸入強度和內存衰減強度)。
Method 2: State Normalization
在每次更新后對狀態進行歸一化,以確保狀態的范數始終低于閾值:
PS:這種方式會將模型轉換為非線性RNN,無法以與原始模型相同的方式并行化,預填充速度要慢得多。
Method 3: Sliding Window by State Difference
利用狀態ht可以寫為加權和的形式,來模擬滑動窗口機制,無需在每一步都從窗口的開頭重新處理。
此方法適用于所有可以寫成加權和的RNN,包括RWKV 5和6、RetNet、GLA等。盡管會使生成的計算和內存成本翻倍,但仍然是一個可以接受的權衡,因為RNN的生成成本比Transformer低很多。
以上3個是不需要訓練的方案,而基于SC是由狀態參數過擬合引起的假設,我們也可以嘗試使用超過狀態容量的序列長度來訓練模型。
容量上限
根據以上的討論,當且僅當訓練長度包含的信息少于狀態容量時,才會發生SC,所以我們可以通過實驗間接估計模型的狀態容量。
研究人員訓練了多個具有不同狀態大小和訓練長度的Mamba-2,并將SC未發生的最小訓練長度視為狀態容量。
實驗數據選擇RedPajama-V2,一個從CommonCrawl中提取的30T token的開放數據集,進行去重以確保數據質量。
在評估過程中,對長度超過16K token的文檔進行抽樣,如果不夠長,則對其進行拼接。
研究人員試驗了具有不同狀態大小的模型配置,包括來自Mamba-2官方checkpoint的三個預訓練模型,大小分別為130M、370M和780M,另外3個模型(36M、47M、85M)則從頭開始訓練。
實驗結果
上圖展示了在Mamba-2 780M上無訓練長度泛化方法的結果。我們可以看到,雖然LongMamba大大提高了模型的長度泛化性(3倍以上),但它在較短的序列上會導致明顯更大的困惑度,并且仍然不可避免地表現出SC。
相比之下,本文的所有的方法都成功地抑制了SC,使模型能夠泛化到超過64K個token。
三種方案中,狀態歸一化在較短序列上的性能大大低于其他方法,這可能是因為歸一化折疊狀態會改變heads之間的規范比率,破壞了學習機制。
上圖顯示了Mamba-2在語言建模和密鑰檢索方面的狀態容量。兩個圖中最右邊的數據點對應于Mamba-2 370M。
左邊的圖可以擬合出一個線性關系,而右邊的圖則表明Mamba-2在密鑰檢索方面的容量與狀態大小呈指數級關系。
這是因為上下文中的信息量不會隨著其長度的增加而增加。換句話說,模型存儲了恒定數量的信息,而狀態的組合數量隨著元素數量呈指數增長。