上海交大、復旦、上海 AI Lab引入漸進學習框架來驗證弱到強的推理
?一、結論寫在前面
論文標題:Weak-to-Strong Reasoning
論文鏈接:??https://arxiv.org/pdf/2407.13647??
代碼等:??https://github.com/GAIR-NLP/weak-to-strong-reasoning??
當大型語言模型 (LLMs) 超越人類水平能力時,為這些模型提供全面且準確的監督變得愈發困難。弱到強學習,即利用能力較弱的模型來解鎖更強大模型的潛在能力,在此背景下被證明是有價值的。然而,這種方法在復雜推理任務中的有效性仍未得到驗證。此外,在弱到強設置下解決推理任務目前缺乏有效方法來避免盲目模仿弱監督者及其錯誤。
本文探討了弱到強框架在復雜推理任務中的效能。論文引入了一種新方法,該方法利用弱監督激發強大能力,無需依賴人類標注或更高級模型的注釋。該方法側重于強模型自主精煉其訓練數據的能力,即使它之前未曾學習過該任務。通過迭代擴展其學習范圍,強模型不斷拓寬其推理技能。這種自我導向的數據治理對于擴大AI推理能力提升的規模至關重要,使模型在其發展軌跡中更加獨立和高效。
論文使用Llama2-70b作為強模型,測試了三個獨立的弱模型:Llama2-7b、Gemma-2b和Mistral-7b,并在常用的數學推理數據集GSM8K和MATH上進行實驗。實驗結果顯示:
1.完全弱微調雖然在分類任務中有效,但在復雜推理任務中表現不佳。
2.論文提出的方法顯著優于完全弱微調方法,在第一階段訓練(M → Mplus)后,僅由弱模型(即Gemma-2b)監督時,在GSM8K上實現了26.99點的改進,并通過偏好優化(Mplus → Mpro)進一步提高了8.49點的性能,而無需知道金標準答案。
3.論文提出的偏好優化階段使強模型能夠從弱監督者的錯誤中學習,最終在具有挑戰性的場景(如4-5級MATH問題)中超越了在金標準解決方案上微調的強模型(即強上限)。
為更準確地模擬未來場景,論文在OlympicArena上進行了額外的實驗,這是一個極具挑戰性的數據集,沒有明確的標準答案。盡管規模較小,但Llama3-8binstruct(AI@Meta,2024)已經經過對齊,并被證明可以有效地監督更大的Llama3-70b,后者的潛力尚未被充分發揮。此外,論文提出的兩階段訓練方法比完全弱微調高出3.19點。
圖1:( a ):使用 Llama2-7b 監督 Llama2-70b 在 GSM8K 上的測試準確率。(b):使用 Llama3-8b-instruct 監督 Llama3-70b 在 OlympicArena 上的測試準確率。"弱基礎" 指的是弱模型的結果。"全弱微調" 指的是基線結果,其中強模型在弱模型生成的完整數據集上進行簡單微調。"論文的階段I" 表示使用論文提出的弱到強方法進行監督微調的第一階段結果。請注意,論文的方法在階段I產生了三種增強的強模型變體,論文在這里展示最佳結果。"論文的階段II" 表示使用論文的方法進行偏好優化的第二階段結果
二、論文的簡單介紹
2.1 論文的背景
"學生不必不如老師;老師不必比學生更聰明。" ——《On Teachers》
隨著人工通用智能(AGI)研究的推進,創造超越人類認知能力的超智能系統一直是該領域的一個關鍵目標)。這一追求帶來了一系列挑戰,尤其是在這些高級AI模型的監督和學習范式方面。傳統的監督方法通常依賴于人類監督或來自更高級模型的指導(即知識蒸餾,distilled knowledge)),但當AI的能力超越其監督者時,這些方法變得不足。
為解決這個問題,論文關注弱到強學習范式(weak-tostrong learning paradigm),該范式在一個獨特的任務設置下運作,即只有一個能力較弱的模型和一個更強大但未充分利用的模型可用。弱到強學習的核心問題是,能力有限的模型是否能有效指導更先進、更強大模型的發展。Burns等人(2023)的先前研究已經證明了這種方法在分類、國際象棋和獎勵建模任務中的可行性。然而,這種設置是否適用于更復雜的推理任務仍是一個開放性問題,這些任務需要的不僅僅是簡單的外推或模式識別。
復雜推理是人類認知的一個關鍵方面,對于評估大語言模型是否能模仿或超越人類理解世界、做出決策和解決問題的能力至關重要。鑒于這些任務的復雜性和關鍵性,將弱到強學習框架應用于高級推理挑戰是至關重要的,特別是在實現超智能的更廣泛背景下。
盡管Burns等人(2023)建議,在弱模型產生的全部噪聲數據上簡單地微調強模型(稱為完全弱微調)可以持續提高其性能超過較弱的對應模型,但這種方法仍遠未恢復強模型的全部能力,而且論文的實驗表明,在面對更復雜的推理挑戰時,它失去了效果。他們還提出了一種輔助置信度損失,以緩解強模型模仿其監督者錯誤的問題。然而,這種方法是為具有一組固定標簽的分類任務量身定制的,不能自然地擴展到包括推理在內的開放式生成任務。目前,在弱到強推理框架內,除了簡單的微調之外,缺乏有效的方法來防止過度擬合弱錯誤并進一步激發強模型的內在推理能力。
為實現上述目標,論文引入了一個漸進式改進學習框架,遵循的原則是模型可以通過最初關注較小、更可靠的數據子集,然后逐步擴大學習范圍來更有效地提高其能力,如圖2所示:
?在第一階段,論文假設利用可能更準確的較小數量的數據更有利。論文通過結合弱模型生成的數據和更高級模型通過上下文學習自生成的數據來實現這一點。然后將這種混合用于有選擇地策劃后續監督微調的數據集。
?在第二階段,在開發出具有改進推理能力的強模型后,論文利用其構建對比樣本進行偏好優化的能力,使模型能夠有效地從較弱模型的錯誤中學習。
2.2 預備知識
2.2.1 大語言模型的典型學習范式
論文概述了大型模型訓練中的常見學習范式,主要特征是需要標準答案和來自更強大模型的監督,如表1所示。
通用監督學習 在訓練大語言模型時,理想情況是擁有足夠數量的帶有標準答案的訓練數據,論文稱之為通用監督學習范式。然而,獲取這樣的數據往往需要大量的標注工作,有時甚至是不可能的。因此,各種學習范式應運而生,以減少數據質量和數量的影響,同時仍能提高性能。
表1:LLMs的典型學習范式?!癡”和“X”表示是否需要監督,“-”表示可選?!癎.T:”代表真實答案
基于蒸餾的學習 在當前背景下,即使沒有標準答案,要提升像Llama2-70b這樣的強大模型,仍可以通過尋求像GPT-4這樣更強大的模型的幫助來實現改進。因此,許多現有工作建議讓一個更強大的模型充當教師模型,為目標模型提供具體反饋以改進。這種范式可以被視為蒸餾更強大教師模型的知識。然而,僅僅模仿教師模型并非長期解決方案;在模仿數據中未充分代表的任務上,模仿模型只能略微縮小與教師模型的性能差距。此外,蒸餾學習主要有益于那些不如教師模型能力強的模型。
自我改進學習 考慮到由人類或更強大的專有模型標注訓練數據的高昂成本,一系列工作依賴于模型自身生成的正確響應來更新它。例如,Zelikman等人(2022)、Yuan等人(2023)、Singh等人(2023)、Hosseini等人(2024)根據最終答案的正確性篩選解決方案,而Lightman等人(2023)、Lin等人(2024)則使用在金標準標注上訓練的獎勵模型來評分自生成內容。顯然,無論是使用二元標簽還是細粒度反饋,這種范式仍然需要標準答案來評估模型自生成響應的可用性。沒有標準答案,自我改進只能帶來最小的性能提升,甚至可能降低性能。
半監督學習 從傳統機器學習領域的半監督學習中獲得啟發,另一種大語言模型學習不依賴于大量標注,而是依賴于一個小型的高質量種子數據集。Tong等人(2024)通過學習自生成響應與專家標注響應之間的差異,展示了改進。論文還將當前流行的研究主題——易到難泛化納入這一類別,其中模型通過學習人類對較簡單任務的標注來解決復雜任務。這一系列研究不可避免地需要獲取一小部分高質量的標準答案。
弱到強學習 在模型超越人類能力的場景中,為復雜任務提供全面和精確監督的挑戰變得更加嚴峻,特別是在沒有標準答案,也沒有更高級模型提供監督指導的情況下。這種缺失凸顯了弱到強學習方法的關鍵重要性。這些方法獨特地利用較弱的監督信號來恢復已經強大的模型中的潛在知識。例如,用GPT-2級別的監督者對GPT-4進行微調,可以在某些任務上恢復接近GPT-3.5級別的性能。這一策略對推動人類社會進步具有深遠意義,它使大語言模型具備解決當前無法解決的數學和物理挑戰的能力。與其他學習范式不同,弱到強學習在相對寬松的條件下運作,為探索和創新開辟了廣闊的機會。
2.2.2 弱到強推理設置
論文在弱到強的設置下處理推理任務,如表2所示。首先,論文研究數學推理任務,如GSM8k和MATH中的任務。這些任務要求推理過程的每一步都展示基本的數學問題解決技能,包括問題理解和代數運算,并在前幾步的基礎上繼續推進。這對模型的學習和泛化能力提出了更高的要求。與分類任務不同,模型可以依賴于表面模式的外推或識別,而推理任務幾乎無法從猜測中獲益。
然后,論文使用一個具有一定數學問題解決能力的弱模型(例如Llama2-7b),記為m。這個模型類似于超智能時代中具有有限專業知識的人類監督者。此外,論文只有一組沒有標準答案的問題Q = {qi,目標是提高強模型M(例如Llama2-70b)的推理能力。
為了實現這一點,論文遵循Burns等人(2023)的方法,將原始訓練集隨機分成兩個相等的部分,Dgold,1和Dgold,2。弱模型最初使用Dgold,1進行微調,其中有可用的標準解決方案,從而得到一個具有一定問題解決能力的弱模型,即m。相比之下,強模型只能訪問來自Dgold,2的問題,沒有推理鏈或最終答案,即Q。
2.3 方法論
在本節中,論文提出了一種弱到強的訓練方法,旨在最大限度地利用弱數據并激發強模型的內在潛力。首先,在沒有標準答案和外部信號的情況下,論文識別出潛在的正樣本。在第一階段,論文僅利用這部分數據進行監督式微調。然后,一旦強模型達到了一定的推理水平,論文就在第二階段使用全部弱數據,特別是通過基于偏好學習的方法(如 DPO,)來利用潛在的負樣本,鼓勵強模型從弱模型的錯誤中學習。整個框架如圖 3 所示。
2.3.1 階段I:從“正樣本”中學習
給定一個弱模型m 和一系列沒有真實標簽的數學問題Q,m 生成弱數據D_weak = {q_i, C_weak,i, a_weak,i },其中q_i ∈ Q,C_weak,i 表示推理鏈,a_weak,i 表示最終答案。a_weak,i 的正確性是未知的。核心挑戰在于:論文如何最大化利用m 和D_weak 來充分增強和恢復一個更強模型M 的數學推理能力?
2.3.1.1 全面弱數據微調
論文的初始策略是對更強模型M 在整個弱數據集Dweak 上進行微調。盡管先前研究(Burns et al., 2023)已驗證了這種方法在文本分類任務中的有效性,但其在推理任務中的效果尚未探索。因此,論文著手研究弱到強泛化現象是否也能在此較少探討的領域增強M 的推理能力。
2.3.1.2 弱上下文學習
另一種直接的方法是上下文學習(ICL, in-context learning),它僅需要幾個訓練樣本作為提示中的演示。具體來說,論文從D_weak 中隨機選擇四個樣本作為演示。由于論文無法訪問真實標簽,這些演示不能被證明是正確的。
圖3:論文的方法概覽,從M 演進為Mplus 再到Mpro。左側:論文利用最終答案一致性來有選擇地從多樣化的來源中過濾弱數據和ICL數據,這些數據用于微調強模型M 并獲得具有增強數學推理能力的Mplus。右側:論文利用Mplus 的置信度來識別對比樣本以進行性能優化,從而得到更穩健的強模型Mpro。
2.3.1.3 弱-ICL微調
鑒于模型可以通過監督微調模仿弱錯誤,論文建議在使用前對Dweak進行精煉,而不是盲目使用所有數據。此外,論文還尋求利用通過上下文學習激活的強模型的固有能力?;谶@兩個想法,論文引入了弱-icl微調,同時使用弱數據D_weak和"icl數據"D_icl = {q_i, c_icl,i, a_icl,i},其中qi ∈ Q,c_icl,i和a_icl,i是由M通過少樣本示例生成的,作為更高質量的監督信號。需要注意的是,對于D_weak和D_icl,論文無法確定某個答案是否正確。
盡管如此,當兩個采用不同數據表示的模型在開放式任務中得出相同答案時,這表明準確性的可能性更高。這種現象支持了在不同方法之間觀察到一致性時結果的可靠性。因此,論文比較由弱模型和強模型分別生成的D_weak和D_icl,并在a_weak,i = a_icl,i時選擇D?weak和D?icl用于后續的監督微調。論文稱這種方法為最終答案一致性??紤]到這兩組數據的組合,論文可以得到三個版本的增強微調強模型:
?M_weak-ft:在D?weak上微調的M。
?M_icl-ft:在D?icl上微調的M。
?M_hybrid-ft:在D?weak和D?icl的并集上微調的M。
迭代訓練 仔細觀察M_weak-ft和M_icl-ft,論文發現它們仍然滿足具有不同數據表示的條件,因為它們是在來自不同來源的數據上訓練的——D?weak由弱模型生成,而D?icl主要源自強模型本身。因此,論文可以進行迭代訓練以提升性能。論文將初始輪次的監督微調數據表示為D?1weak和D?1icl,得到模型M1weak-ft、M1icl-ft和M1hybrid-ft。在第二次迭代中,論文將M1weak-ft應用于Q以構建D2weak,將M1icl-ft應用于構建D2icl。這里,下標"weak"和"icl"表示初始數據來源。然后論文應用最終答案一致性來獲得D?2weak和D?2icl。經過另一輪監督微調后,論文得到:
?M2weak-ft:在D?2weak上微調的M。
?M2icl-ft:在D?2icl上微調的M。
?M2hybrid-ft:在D?2weak和D?2icl的并集上微調的M。
需要注意的是,迭代訓練步驟是可選的;當數據質量過低或模型過擬合時,可能會導致性能下降。
2.3.2 第三階段:從“負面”樣本中學習
論文將第一階段的最終迭代模型表示為 Mplus,該模型已學習了雙重數學解決方案,并具有進一步增強的潛力。接下來,論文應用偏好優化技術,戰略性地利用由m 生成的原始弱數據集Dweak={q_i, c_weak, a_weak,i}中的潛在錯誤子集,使得強模型能夠識別并避免在未來的推理過程中出現類似的錯誤。關鍵在于如何構建用于學習的對比樣本。
在沒有訪問真實答案的情況下,當前具備增強推理能力的強大模型會根據其置信度識別最可能正確的答案。具體而言,對于每個問題q_i 屬于 Q,論文從模型Mplus 中抽取n 個回答,并將這些回答中出現頻率最高的答案的概率定義為置信度。當置信度低于指定閾值τ 時,論文認為模型對這一問題的判斷不可靠,因此將其舍棄。相反,如果置信度不低于τ,論文則認為模型能夠解答該問題,并繼續構建對比樣本,具體步驟如下:
進一步在樣本上訓練M_plus使其能夠區分正確與錯誤的解決方案,從而得到一個更強的模型M_pro。
2.4 實驗
2.4.1 數據集
GSM8K和 MATH是兩個廣泛使用的數學推理數據集,其中 MATH 包含更具挑戰性的競賽問題。論文使用的數據統計信息如表 4 所示。特別是,為了確保弱模型有足夠的訓練數據來培養初步的數學技能,論文通過 Chern 等人(2023)構建的數據增強了 GSM8K 訓練集。
表 4:數據統計。Dg o l d, 1 和Dg o l d, 2 是訓練集的子集。弱模型使用Dg o l d, 1 來培養初始數學技能,而強模型只能訪問Dg o l d, 2 中的問題,沒有正確答案
圖4:第一階段的主要結果。第0^m 輪展示了兩個基線的性能,其中“weak”表示完全弱微調,即在全部弱數據上進行簡單微調,“icl”指的是不進行微調的弱ICL。連線表示模型共享相同的訓練數據源。低于“強上限”的結果顯示了通過貪婪解碼的測試準確率,而高于“強上限”的結果顯示了pass@k分數( k=10 和溫度=1.0 )。為簡潔起見,論文僅展示了通過貪婪解碼超越的Mhybrid-tad 檢查點的pass@k分數,完整結果在A.4.2 中提供
2.4.2實驗設置
論文使用Llama2-70b作為強模型,并采用來自不同家族的三種弱模型:Llama2-7b、Gemma-2b和Mistral-7b。論文對弱模型在D_gold,1 上進行全參數微調,并一致采用LoRA對強模型進行微調。在第一階段,論文根據迭代原則,在GSM8K上進行兩輪迭代,在MATH上進行一輪迭代。在第二階段,論文采用基于偏好學習的兩種方法,DPO及其變體ORPO。
論文在測試集上評估準確性。弱模型m 的性能與通過Dgold,2 中的黃金解決方案數據微調的強模型M 的性能相結合,代表了強模型與弱模型結合的最佳性能。
2.4.3 第一階段結果
GSM8K和MATH數據集上第一階段的主要結果如圖4所示。值得注意的是,在MATH實驗中,由于可用數據量較小,論文隨機抽取了未根據最終答案一致性選擇的數據。根據圖4,論文有以下觀察結果。
弱ICL微調顯示出顯著提升。使用論文提出的方法,僅由在GSM8K上準確率為25.17 的弱Gemma-2b監督的強模型性能,可以提升至60.12,超過簡單全弱微調31.08,并且超過Mplus(即Mhybrid-ft^2)。隨著弱模型的改進,這一結論在分類任務上得到了Liu和Alahi(2024)的驗證。具體而言,GSM8K上的性能逐漸提升,從Gemma-2b到Llama-7,再到Mistral-7b(25.17 -> 33.81 -> 59.51)。因此,通過這些模型生成的數據訓練的強模型的最大性能也逐步提升(60.12 -> 63.76 -> 68.39)。
Mhybrid-rt 實現了最高的 pass@k 分數。正如預期,Mhybrid-t 在弱到強設置中取得了最高的 pass@k 分數,這得益于其訓練數據融合了兩種類型的解決方案——一種來自弱模型,另一種來自強模型。這種多樣性通過降低過擬合的可能性增強了模型的魯棒性。此外,Mia-t 的表現通常優于 Mweak-ft,這可以歸因于過程級精度的變化和可能的解決方案格式。
簡單的微調不足以應對弱到強的推理任務。當使用 Gemma-2b 作為 MATH 數據集上的弱模型時,完全弱微調的表現不如弱基準(10.0 對比 11.6)。這表明,盡管簡單的微調在分類、國際象棋和獎勵建模任務中成功應用(Burns et al., 2023),但對于復雜的推理任務,尤其是像 MATH 中的高難度問題,這種方法顯得力不從心。相比之下,論文的弱-icl 微調方法有效地彌合了這一差距,為弱到強推理挑戰提供了一種有效且可擴展的解決方案。
ICL性能的影響 考慮到弱-icl微調的有效性部分取決于弱ICL的效果,論文進一步探討了通過謹慎選擇示例來增強ICL性能如何影響弱-icl微調的表現。圖5展示了使用Gemma-2b作為弱模型,在不同示例集下GSM8K測試的準確率。結果表明,使用這組特定示例的弱ICL性能從原始的56.48提高到了64.06。
隨后,論文在提示中使用這些示例重新生成Dicl,并在通過最終答案一致性精選的D?icl上微調強模型。這進一步將性能從64.06提升到64.75,證實了自主數據篩選的有效性。
值得注意的是,盡管弱ICL具有高性能的潛力,但在弱到強的框架中選擇有效的示例并非易事,這超出了本文的討論范圍。
2.4.4 第二階段結果
論文采用Mhybrid-ft的最終迭代作為Mplus進行后續的偏好學習。實驗結果驗證了該檢查點達到了更高的pass@k,并具有進一步提升的顯著潛力。
如表5所示,論文構建正負樣本的方法有效地增強了強模型的數學推理能力。在GSM8K上,DPO和ORPO使用論文構建的數據集都持續取得顯著改進,特別是在由Gemma-2b監督時,增加了8.49個百分點。盡管MATH問題本質上具有挑戰性,這影響了強模型的判斷并在訓練數據中引入了不準確性,但論文的方法通過ORPO仍然在MATH上取得了至少1個百分點的改進。
數據構建方法 在構建偏好數據時,論文始終使用由弱模型生成的弱響應作為被選擇/拒絕的響應之一,而不是完全依賴自生成的數據。論文還在GSM8K上使用Llama2-7b作為弱模型測試了自生成設置,其中被選擇和被拒絕的響應都由強模型自身生成。在這種設置下,DPO測試準確率為62.40(-0.22),表明性能略有下降。在沒有真實標簽的情況下,構建的正負樣本實際上分別對應于更頻繁和較少出現的答案,并與模型傾向于選擇的答案相關。由于偏好優化本質上執行排序,這種自生成設置的潛在收益是最小的。因此,在偏好數據構建過程中引入弱數據信號被證明是一種更好的方法。
2.4.5 分析
為進行進一步分析,論文檢查了MATH測試集中不同難度級別的準確率。
如圖6所示,強模型在較簡單的問題上表現出更好的泛化能力。具體來說,盡管Llama2-7b在1級問題上只達到了6.98點的準確率,但Llama2-70b在使用這種弱監督訓練后,可以在1級問題上達到超過30點的準確率。對于更具挑戰性的問題(4-5級),經ORPO增強的Mpro甚至超過了僅通過金標準解決方案監督微調獲得的強模型上限。這一現象驗證了從不正確數據中學習的有效性。
2.4.6 更接近未來場景的實驗
在對Llama3-70b(AI@Meta,2024)的初步測試中,論文觀察到在GSM8K和MATH上,Llama3-70b可以通過上下文學習在很大程度上釋放其潛力,而參數更新由于訓練不穩定性而產生邊際甚至負面影響。因此,論文聚焦于Llama3-70b發布后開發的更具挑戰性的數據集OlympicArena,以模擬更真實的未來場景。
論文僅考慮OlympicArena中的英語問題,排除了需要基于案例或專家評估的CODE(代碼生成)和OT(其他)問題類型。這樣得到了6,020個沒有解決方案和最終答案的訓練數據,以及313個有最終答案的測試數據,用于評估不同方法的性能。論文使用Llama3-8b-instruct(未在訓練數據子集上進行初始微調)作為弱模型,Llama3-70b作為待改進的強模型。超參數與GSM8K中使用的一致。這種配置更接近未來真實世界的弱到強場景。
實驗結果如表6所示。"Weak Floor"代表Llama3-8b-instruct的零樣本性能,"Full Weak FT"表示Llama3-70b在訓練集上由Llama3-8b-instruct生成的全部(即6,020個)弱解決方案上監督微調后的性能,"Weak ICL"表示Llama3-70b在Llama3-8b-instruct生成的4-shot弱示例下的性能。盡管參數更多,但由于挖掘能力不足,Llama3-70b在上下文學習下的表現仍低于Llama3-8b-instruct的零樣本性能。
通過論文提出的弱-icl微調方法獲得的M1 weak-ft,以更少的訓練數據(即746個)達到了比Full Weak FT更高的性能,超過了0.32個百分點。經過第二階段的偏好優化,進一步利用弱模型和沒有答案的訓練問題,強模型的性能比Full Weak FT又提高了3.19個百分點。這證明了論文的方法在更接近未來條件的場景中的穩健性和泛化能力。
本文轉載自 ??AI帝國??,作者: 無影寺
