增強大模型智能:數學推理能力的提升策略與實踐
一、大語言模型概述
首先來回顧一下大模型的基本結構。上圖中列出了當前一些主流大模型,比如 GPT 系列中的 GPT-3,發布于 2020 年,擁有 175B 參數,還有 Huggingface 的 Bloom、清華的 GLM 系列、Meta 的 LLaMA、百川的 Baichuan 和阿里的 Qwen 系列等等。除了清華的 GLM 使用的是 Prefix decoder,這些模型大多采用與 GPT 類似的架構。
這些模型的參數規模各不相同。GLM 系列除了最大 130B 的模型外,還有 6B 和 10B 的版本。Meta 的 LLaMA 系列有 65B 及其他不同規模的版本。千問系列有 7B、14B 和最大的 110B。這些開源模型為業界公司提供了很多優化的思路。
大模型的結構在業界已較為標準化,主要基于 transformer 結構。關鍵參數包括詞表、transformer 層數、Multi-head 和全連接層。以 GPT-2 為例,它是一個 1.3B 參數的模型,詞表大小 5 萬,層數 24 層。根據參數計算公式, Embedding 層的 d_model 為 2048,乘以 5 萬,得到其參數規模。QKV 計算、Attention Project 和 FFN 等參數加起來,最終得到 1.3B 的總參數。
大模型優化方面,常用的方法包括 SparseAttention、FlashAttention,以及其他結構如 MAQ 和 GQA 的優化,但整體結構仍基于 transformer。
大模型結構中,關鍵部分包括圖上面的 Multi-head 和下面的點積注意力計算,右側是大模型的總體結構示意。針對 Attention 的優化有 FlashAttention、SparseAttention 和 GQA,位置編碼有絕對位置編碼和 RoPE 等相對位置編碼。優化主要是為了提升大模型的外推能力,尤其對長文本效果更好。此外,還有對激活函數和其他細節的優化,業界在這些方向上都做了很多工作。
大語言模型的構建通常分為四個部分。以 OpenAI 為例:①預訓練,這是資源消耗最大的一環,通常使用 1000 多塊 GPU,訓練周期長,數據量達到數千億 token,約幾 TB;②SFT 層(有監督微調),主要優化指令對齊,數據量較少,通常為百萬級,少數達千萬級,訓練時間為天級別;③訓練獎勵模型(Reward Model);④人工反饋強化學習(RLHF),這部分的成本與 SFT 相似,但如果使用傳統 PPU 顯存占用較高,數據量允許天級別完成。
LLaMA 等模型也遵循類似流程,比如說 LLaMA2-Chat,分為預訓練、SFT、強化等階段,根據人類反饋調整指令偏好。
大模型的構建可以分為三個部分。第一階段是指令學習階段,通過訓練基座模型,使其理解人類指令,并根據人類編寫的指令和高質量回答進行 SFT。第二階段是讓大模型更擬人性或者是更符合人類偏好(人類對模型輸出進行偏好排序)。第三階段是人類反饋強化學習階段,是由四個模型構成:Reference Model(訓練好的參考模型)、Reward Model(對生成結果評分的模型)、Actor Model(需要強化的模型)和 Critique Model(訓練過程中的評分模型)。
前面回顧了大語言模型的基礎結構,接下來將介紹數學推理優化的流程,分為四塊:數據構建、數據篩選、模型構建、模型訓練與優化。數學推理的數據分為混合指令和合成數據。其中合成數據是對當前數據的擴展,因為高質量的數學數據,尤其是應用類指令較少。數據篩選,包括質量篩選和多樣性篩選,避免重復或相似問題。篩選原則依賴于 Reward Model 或 Critique Model。模型訓練使用 Reference Model,訓練好 SFT 后進行質量和多樣性篩選,歸為 RFT 流程,即拒絕采樣流程。在 Reward Model 或 Critique Model 中,使用 PPO、DPO 或 RFT 流程。
接下來將詳細介紹混合指令、合成數據和訓練優化的具體做法。
二、混合指令
數學問題可以拆解為邏輯推理和數學應用兩類。數學應用早期主要采用思維鏈(CoT)模式,后來為解決計算問題,引入了 PoT(Program-of-Thought)模式。當前的思路是數學分析或邏輯推理放到 CoT 部分處理,涉及計算的問題,如解方程或微積分計算,放到 PoT 部分。因此,混合指令由這兩部分構成。
這樣做有兩個原因。首先,CoT 并不擅長復雜運算,尤其是積分和方程運算。盡管大模型在預訓練中可以處理簡單運算(如三位數的加減乘除),但對于更高階的數學運算,PoT 的準確率更高,讓大模型專注于擅長的部分。
其次,單純使用 PoT 也有問題。在涉及需要推理的數學場景(如抽象代數和幾何運算)時,PoT 顯得不夠直觀,難以一步步推理。此外,它在整合前后邏輯關系時也存在問題。
所以現在我們使用的是混合指令。混合指令的前一部分是標準的 CoT 模式,比如 GPT-4o 的回答,前面的推理、中間的評分計算、合并同類項都是靠其數學推理能力一步步解決的。但我們發現最后的合并同類項出現了錯誤,前面的推理是完全正確的,公式引用也沒問題,但在數值計算方面有誤。
左側的方案將其拆分為:前面采用 CoT 思維鏈模式,類似于 GPT-4o,而在最后的計算部分,使用 PoT 來提高準確性。這個方法對于大模型的數學推理來說雖然不復雜,但確實簡單有效。
三、合成數據
接下來介紹合成數據。在預訓練時我們能獲得大量數學題目,但以英文為主。進行二階段 SFT 時,我們發現開源的好數據很少。常用的數據集如 GSM8K 和 MATH,雖然不錯,但數量有限。GSM8K 是小學數學推理題,MATH 類偏向競賽題。
大模型在解題時表現優秀,但讓它生成新問題則相對困難。這是因為解題需要的是運算能力,而生成新問題需要更高層次的思考和創造能力。
合成數據的 Self Instruct 是常用的方法,此方法早已提出。我們在種子任務中有部分高質量的數學問題集合,無論是購買的還是自建的。我們希望從這些高質量集合中擴展出更多樣化的數學指令。為此,將其細分為數學問題,按學科拆解,如矩陣運算、微積分、方程等。拆解后,再對每個子問題進行 Self Instruct,以擴展種子任務。篩選時,若只對指令篩選,可用最長公共子序列或 Jaccard 距離等簡單方法。
指令構建和篩選相對容易,但指令能否提供更多樣化的問題則是一個難點。有些解出的題目不適合作為訓練集,因此需嚴格把控指令和回答的質量。我們訓練過 Reward Model,最新的英偉達 340B 模型評分最高為 92 分,我們的模型為 86.8 分,排第五。86.8 分包括所有任務,如生成任務和翻譯任務。
針對數學類問題,我們理想的 Reward Model 評分分布應是正態分布,實際情況中,GPT-4o 評分在正確和錯誤回答間有明顯區分度,但我們 Reward Model 的評分分布不明顯。訓練時,Reward Model 對同一問題的正確和錯誤答案進行排序,而非絕對值評分。因此,Reward Model 能合理地對相同問題的生成進行排序,但不同問題間的絕對值評分參考意義不大。
在質量過濾時,不僅考慮相同問題,還要考慮不同問題之間的差異。因此,我們選擇了 Critique Model 進行絕對值打分。例如,左圖中,先用 Reward Model 對 n 個問題評分,取前 M 個高分,再用 Critique Model 從下往上卡絕對值。
Critique Model 的訓練如中圖所示:首先構建指令,明確角色;然后提供參考答案和模型回答;最后,GPT-4o 給出步驟和最終分值。
整個 Critique Model 訓練流程如下:從數據中提取問題和對應的參考答案,中間部分是標準指令,指導模型生成評判標準。最下面是 GPT-4o 或其他模型生成的打分結果。我們用這些數據訓練 Critique Model。GPT-4o 對問題的打分準確率為 85.94%,Critique Model 訓練后約為 84.76%。可以看到,GPT-4o 和 Critique Model 的最終打分分布差異明顯。
四、訓練優化
訓練分兩階段:RFT 階段和強化階段。
在 RFT 階段,我們采用這種方法有其背景。之前在大模型進行數學推理時發現,即使指令集不大,如果為每個問題生成多條不同的合理推理路徑,可以提升模型的多樣性和能力。因此,在 RFT 階段,我們先訓練一個 chat 模型,例如 LLaMA 進行 SFT 訓練。一階段訓練后的模型在二階段生成多條推理路徑,經過 Reward Model 和 Critique Model 的質量過濾和多樣性篩選。最終數據包含每個問題的多條推理路徑,再用于更大模型進行 RFT。
使用小模型生成和篩選數據,是因為大模型采樣成本過高。例如,10 萬條指令每條采樣 100 次,共生成 1000 萬條數據,用大模型成本較高,而小模型生成數據更節省時間成本,其生成的推理路徑更為多樣化。
上圖中可以反映出小模型的優勢,比如右上角的 LLaMA 模型,我們可以看到 33B、7B 和 13B 的模型,其中推理路徑貢獻最大的一部分并不是 33B,而是 7B。下面的圖也顯示,7B 和 14B 的模型分別貢獻了 41% 和 39% 的推理路徑,而中間兩個模型相交的推理路徑只有 19%。這說明更小的模型在數據生成和采樣方面,能得到更加多樣化的推理路徑。
整個 RFT 流程是使用較小的模型,例如我們會用 LLaMA 的小參數模型,來生成和過濾推理路徑,并進行多樣性選擇,然后再將這些數據用于更大的模型進行 RFT。質量過濾包括 Reward Model 打分和 Critique Model 打分,多樣性篩選是關鍵,因為重復的回答對大模型并不友好。
上圖中展示了詳細流程,比如左邊圖中的推理路徑由 r1 到 r3,再加入一個新路徑 r4。我們會計算 r1 到 r4 的相關性或距離,如果 r4 超過前兩個路徑的距離,就會替換其中一個,以保證選出路徑間距離最大化。在我們的流程中,重點在 PoT 部分的多樣性選擇。PoT 部分相對結構化,不同推理路徑會反映在 PoT 部分的不同實現方式上。
可以看一下,有三條路徑對應三個部分的 PoT。路徑一和路徑二在 PoT 部分看似不同,但只是注釋和變量命名不同。如果抽取關鍵信息,規范化變量命名并去掉冗余信息,會發現它們是完全相同的推理路徑。只有路徑三是真正不同的推理路徑,通過設未知數和方程來實現不同的推理。因此,設置關鍵信息抽取模塊,去掉冗余信息和規范化變量命名后,再計算相關性或距離度量,用作多樣性篩選的一個評判標準。
最終實驗結果顯示,我們的模型每次采樣 100 次,平均生成約 7.8 條推理路徑。
我們也評估了準確率。在一個評測集上,SFT 后的準確率為 71%,RFT 為 77%。但 DPO 部分沒有顯著提升。DPO 從 RFT 中采樣得分最高的答案(如九分)作為正例,得分最低的(如兩分)作為負例,并訓練 DPO 模型。訓練時加了輔助 loss 以與 reference 對齊。但九分和兩分的差距較大,DPO 能學到兩者的差異,但在難以區分的問題上優化效果不佳。
復盤發現,DPO 提升不明顯的主要原因:①在簡單問題上,答案更固定化,導致多樣性減少。②字數控制等方面做得更好,使得模型的分布更尖銳,logistic 輸出更精準,但對難題的優化效果有限。
我們對 DPO 部分進行了優化,不再用九分和兩分構建數據 pair 訓練 DPO 模型,而是使用一些難以區分的問題。例如,數學中的精度控制問題,CoT 錯誤但 PoT 正確的問題,或多步 PoT 的難題。我們將這些難以通過 SFT 解決的案例放入 DPO。
我們做了兩部分優化:PPO 和 DPO。最終效果顯示,DPO 勝率為 17%,負率為 10%,差距為 7%;而 PPO 的差距僅為 1%。上圖中右邊的案例顯示,SFT 難以解決的問題在 DPO 后確定性更好,減少了生成的隨機性。
這是一個早期的工作,講述了為什么要使用動態 loss。我們發現簡單的數學或邏輯推理,7B 或 10B 模型就能很好地解決。在訓練初期,準確率在前兩個樣本達到峰值,后續訓練效果不明顯。
而 hard sample 則需要更多輪訓練才能收斂。舉例來說,從前面 233 個 step 到最后 2047 個 step,loss 在后期才平緩。我們定義 hard sample 為模型有十條推理路徑,但 Critique Model 打分準確率低于 50% 的問題。對這些 hard sample,特別是 PoT 部分,進行動態 loss 加權。
我們一直在進行數學推理的研究,作為大模型通用能力的一部分。上圖中展示了今年 4 月的 Superclue 評測數據,這是一個閉源的第三方評測,看不到具體問題。數據顯示 GPT-4-Turbo-0125 擁有最佳表現(GPT-4o 尚未推出),國內大模型中成績最好的是從容大模型,接著就是 360gpt-pro,得分為 75.5 分。
以上就是本次分享的內容,謝謝大家。
五、問答環節
Q1:之前提到的 DPO 和 PPO 是基于兩個測試集的結果,還是在兩個不同的問題領域中的表現?另外,這兩個方法之間存在什么主要差異?
A1:那個評測是在一個評測集上的,都是數學推理類的問題。我們做了兩部分工作,一部分是 PPO,另一部分是 DPO。當時在構建 pair 對時,是根據 RFT 的最高得分和最低得分來構建的。這部分數據是重新構建的。
Q2:關于您們的合成數據工作,包括最近其他的合成數據研究,比如騰訊的 10 億人設研究。您覺得為什么這種合成數據能在復雜推理任務中發揮作用?另外,您認為合成數據在復雜推理任務中的上限是什么?因為看騰訊的研究,Scaling 曲線表現很好。
A2:這個問題很好,也是我們目前在做的,我們數據組尤其關注合成數據。為什么要做合成數據?因為現有指令少,尤其是數學類的。我們需要更多的指令,同時要提高指令的難度。比如,現有的 GSM8K 和 MAS 類指令只能擴展到小學數學應用和競賽題目,這在多樣性和難度上都有問題。我們的做法是將問題細分為數學應用類、矩陣運算類、積分類等子類。每個子類下由標注人員構建種子指令,然后再進行數據合成。第一步必須做到位,第二步才能有效。
合成數據在復雜推理任務的天花板在于篩選邏輯。如果篩選機制好,生成模型足夠優秀,就能生成更好的指令。要對指令進行關鍵詞抽取,再根據 token 級別擴展,生成的指令才會更好。篩選機制也很重要,不僅要篩選好的指令,還要篩選指令的回答,這兩者決定了天花板的高度。
英偉達的研究也展示了合成數據的重要性。只有 2 萬條數據是人工標注的,98% 是合成數據。他們的篩選方法尤其對 MAS 類問題進行了分類,但主要針對簡單問題,像 GSM8K 的簡單替換。而在數學推理外,如 close QA 或 open QA 類問題,英偉達的方法可能會生成與原數據分布相似的數據,這不是我們想要的。我們需要分布之外的數據,有擴展性的合成數據。英偉達還注重 reward model 的訓練,特別是 340B 的 reward model,這部分工作在于區分難分的指令。因此,合成數據需要細分領域或技能,最終的質量和多樣性決定了效果。
Q3:老師您好,我們看到 Critique Model 和 GPT-4o 的打分分布已經接近,Critique Model 的大小是否考慮了不同參數量的影響?您提到生成樣本數據時會用一個特別小的模型,所以判別模型也會很小,但英偉達的 reward model 很大。
A3:Critique Model 比 reference model 小很多。Critique Model 和 reward model 不同,reward model 很大,但Critique Model 不能太大。reward model 推理速度快很多,但它是二分類模型;而 Critique Model 是語言模型,兩者屬于不同類型的模型。