函數向量對齊技術,讓大模型持續學習不“失憶”丨ICLR 2025
LLMs為什么總是災難性遺忘?原來是功能激活在搞怪。
最近來自中國科學技術大學、香港城市大學和浙江大學的聯合研究團隊,通過對多個語言模型、任務序列和評估指標的系統分析,終于破解了LLMs的災難性遺忘之謎——
遺忘行為具有高度的模型依賴性,而導致遺忘發生的本質卻是功能激活的變化。
對此,團隊基于函數向量構建遺忘分析框架,刻畫和分析LLM內部功能的變化(其中功能表示模型對某具體任務的處理能力,如求反義詞、乘法計算),進一步證實了遺忘并非簡單地覆蓋已有函數,而是模型激活了帶偏差的新功能。
研究人員還設計了一種函數向量引導的訓練方法FVG,在微調過程中可以有效保留并對齊函數向量,并在多個持續學習數據集上顯著保護了模型的通用學習能力和上下文學習能力。
目前相關研究論文已被ICLR2025 oral接收,代碼也已在GitHub上公開。
接下來,我們一起來看看詳細細節。
大語言模型的“記憶困境”
災難性遺忘是指模型在學習新任務時,之前學到的知識被新任務的學習過程所覆蓋或干擾,導致模型在舊任務上的性能大幅下降。
例如,一個通用語言模型在學習新增的用戶指令處理后,面臨數學推理能力的顯著下降。
這種遺忘現象不僅影響模型的泛化能力,也使得模型在實際應用中難以適應動態變化的任務需求。
盡管語言模型的災難性遺忘問題受到了廣泛關注,但當前的研究多集中于通過單一訓練序列分析遺忘現象,忽略了不同任務組合對模型表現的復雜影響,同時也缺乏對遺忘內部機制的深入理解。
為此,作者首先通過實證研究探討了大語言模型在持續指令微調(模型在一系列指令微調任務上持續學習)中的遺忘現象,重點考察任務類型、訓練階段以及不同模型之間的差異。
作者使用SuperNI數據集[1]構建六種任務序列,覆蓋生成任務、分類任務以及二者混合任務,并且關注三個指標量化模型對不同能力的遺忘程度:
- GP(General Performance):通用任務的零樣本性能下降。
- IP(In-context Performance):通用任務的上下文學習性能下降。
- FP(Final Performance):訓練任務的性能性能下降。
得到以下結論:
- 無論是通用任務、新任務,還是上下文能力,都出現不同程度的遺忘。
- 任務類型影響遺忘程度:生成任務序列導致的遺忘顯著高于分類任務。
- 訓練階段遺忘可逆:訓練初期可能出現性能下降,但后期有明顯恢復趨勢,表明模型可能逐漸恢復部分遺忘能力。
- 模型差異顯著:遺忘現象受模型結構與預訓練數據影響。
既然模型在不同任務和不同模型中呈現出不同的遺忘現象,導致遺忘發生的本質現象究竟是什么呢?
揭示模型內部函數的遺忘本質
函數向量
作者為了解釋模型發生遺忘時的內部機理,引入了函數向量(Function Vectors, FVs)[2]這一工具。
函數向量是一種定位和表征LLM內部處理具體任務能力的方法,其采用activation patching方法對上下文學習過程中的隱狀態進行干預,識別在任務執行中起因果作用的注意力頭集合。
函數向量通過在這些注意力頭的平均激活值上求和得到。
具體而言,對于一個給定任務的數據集,函數向量的提取分為兩個步驟:
1、因果注意力頭識別
首先對模型的注意力頭進行干預,使用標簽打亂的提示(counterfactual prompt)與原始輸入組成反事實輸入,通常這會導致預測錯誤。
然后將反事實輸入在某注意力頭的表示替換為真實任務的平均激活,并計算該替換對預測結果的因果影響:
其中,表示層、頭在任務上最后一個token的平均激活。
而CE越高,表明該頭對任務表現越關鍵。
最終選擇CE值前10的注意力頭構成集合。
2、函數向量的構建
將集合中的所有注意力頭的平均激活向量求和,得到函數向量:
通過分析函數向量,研究人員發現,災難性遺忘并非是因為模型的任務處理能力在訓練過程中被破壞,而是由于模型在輸入到激活對應任務功能過程中的偏差所導致的。
換句話說,模型并沒有忘記之前學到的任務處理能力,而是這些能力未被正確激活,反而被新引入的能力所掩蓋。
作者采用這種能夠反映模型在處理特定任務上的功能特性的方式,追蹤遺忘現象發生時模型內部功能的變化。
函數向量與遺忘的關系
通過實驗分析,作者發現函數向量的變化與模型的遺忘現象之間存在顯著的強相關性。
具體來說,記為測試任務,為任務在初始模型下的函數向量表示,則表示在訓練完第個任務后任務的函數向量表示。
當函數向量與的相似度較低時,模型在測試任務 上的性能下降較嚴重。
當函數向量與的相似度較高時,模型在測試任務 上的性能下降則不明顯。
具體而言,在訓練NI-Seq-G1數據時Hellaswag的函數向量的相似度與模型性能之間的相關系數(R2值)可以達到0.873。
作者也收集了模型在不同訓練序列,不同seed下的40個checkpoint,并統計了多個測試任務在這些模型下的函數向量相似度與具體性能,可視化結果如下圖:
圖中顯示,當任務學習后的函數向量(FV)相似度較高時,模型的遺忘現象相對較輕,兩者之間存在較強的相關性。
相比之下,Last hidden state的相似度和參數變化前后的L2距離并沒有呈現出這種規律。
模型遺忘的本質
作者基于此方法研究函數向量在任務切換前后的變化,并用作揭示災難性遺忘根源的分析工具,該方法強調遺忘主要源于模型激活偏差的新功能,而非覆蓋舊功能。
作者首先依照潛變量模型(Latent Variable Model)的假設將大語言模型重新表述,具體如下:
LLM 的輸出概率被分解為對所有可能內部功能的積分:
- :在給定任務功能下的輸出概率(即執行某個特定任務功能)
- :在輸入條件下激活該功能的概率(即功能激活機制)
而在函數向量的幫助下,我們可以獲得功能的具體表達形式,得到以下公式:
具體功能被表示為一組隱狀態組合,其中索引來自集合,是激活權重,這個組合決定了處理當前任務功能的具體數值表示。
作者發現函數向量的偏移(即的變化)意味著模型功能激活機制的變化,而在前文函數向量的偏移也與遺忘強相關。
故這些現象共同支撐了一個中心論點:遺忘并非因為模型改寫了執行舊任務的功能,而是因為輸入激活機制發生了偏移,從而未能正確調用這些功能。
可以從上圖獲得更直觀的理解:通過將模型重構為潛變量模型,它被劃分為任務功能的激活和任務功能的執行兩個階段。
在學習任務1之前,模型能夠正確激活任務0的功能,從而做出正確的預測。
但在學習任務1之后,模型可能引入了一個新的函數向量,這個新的向量會對任務0輸入的激活過程造成干擾,從而導致遺忘現象的發生。
此外,作者還通過干預實驗進一步驗證了模型遺忘的內在原因。
僅通過在模型中插入被遺忘能力的函數向量或移除當前訓練任務的函數向量,研究人員就能夠顯著恢復在被遺忘任務上的能力。
函數向量引導的訓練方法
基于函數向量的分析結果,論文提出了一種新的訓練方法——函數向量引導的訓練(Function Vector Guided Training, FVG)。
這種方法的核心思想是通過正則化技術限制函數向量的變化,從而在模型學習新任務時保持其對舊任務的功能激活模式。
具體來說,FVG 方法引入了兩個新的正則化項:
1、函數向量一致性損失
通過限制函數向量的變化,確保模型在學習新任務時不會過度偏離其原有的功能激活模式,具體公式為:
其中,和分別表示在任務和任務時,模型在特定頭的激活值,是距離度量,作者采用L2距離。
2、函數向量引導的KL散度損失
通過最小化零樣本輸入與函數向量干預后的輸出之間的差異,確保模型在微調后仍能保持與原有任務函數的一致性,具體公式為:
其中,是模型在輸入上的輸出概率分布,是在函數向量干預后的輸出概率分布。
最終的優化目標是:。
其中,是語言模型的原始損失函數,和是超參數,用于平衡不同損失項的權重。
實驗驗證
作者在多個數據集和模型上進行了廣泛的實驗,驗證函數向量引導的訓練方法的有效性。
實驗結果表明,FVG方法在多個基準測試中顯著提高了模型在一般任務和上下文學習任務上的性能,同時保持了模型對新任務的學習能力。
結語
本文,作者通過引入函數向量方法,深入探討大語言模型中的災難性遺忘問題,強調了函數向量在表征與緩解遺忘現象中的關鍵作用。
作者在多個基準任務上的分析表明,模型的遺忘行為與潛在功能變量(由函數向量刻畫)發生的偏移密切相關。
基于這一發現,作者提出了一種全新的函數向量引導訓練策略,該方法結合了正則項與函數向量引導的KL散度損失函數,顯著減少了遺忘現象,從而提升了LLMs在持續學習中的通用能力與上下文學習能力。
[1] Wang, Yizhong, et al. "Super-NaturalInstructions: Generalization via Declarative Instructions on 1600+ NLP Tasks." Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing. 2022.
[2] Todd, Eric, et al. "Function Vectors in Large Language Models." The Twelfth International Conference on Learning Representations.
論文鏈接:https://arxiv.org/abs/2502.11019項目鏈接:https://github.com/GangweiJiang/FvForgetting