「專業智能體指導」讓小模型學會數學推理!微調Mistral-7B實現86.81%準確率
對于小型語言模型(SLM)來說,數學應用題求解是一項很復雜的任務。
比如之前有研究結果顯示,在GSM 8K基準測試中實現80%以上準確度所需的最小模型尺寸為340億個參數。
為了在較小的模型上達到這種性能水平,研究人員經常訓練SLM來生成Python代碼或使用外部工具作為輔助,以避免計算錯誤。
或是基于集成(ensembling)技術,將100多個模型生成的輸出組合在一起,以獲得更準確的結果,最終結果的選擇需要通過共識、多數表決或與SLM結合使用的單獨的驗證器模型來完成,可以顯著提升準確率(Phi-GSM使用top-48將性能從68.2提升到81.5),不過代價是由于多次調用模型導致的成本顯著增加。
最近,微軟的研究人員提出了一個基于Mistral-7B、70億參數量的小型語言模型Orca-Math,它在GSM 8 k上實現了86.81%,不需要調用多個模型進行集成或使用驗證器、代碼執行或任何其他外部工具。
論文鏈接:??https://arxiv.org/abs/2402.14830??
Orca-Math的關鍵特性為:
1. 使用多個智能體(agent)創建出20萬個數學問題的高質量合成數據集,其中智能體合作創建數據;
2. 迭代學習技術,使SLM能夠練習解決問題,接收對其解決方案的反饋,并從包含SLM解決方案和反饋的偏好數據中學習。
當單獨使用有監督微調訓練時,Orca-Math在GSM 8 k pass@1指標上達到81.50%。通過迭代偏好學習,Orca-Math實現了86.81%的pass@1
Orca-Math超越了LLAMA-2- 70B,WizardMath-70B,Gemini-Pro,ChatGPT-3.5等更大型號的性能,在使用小得多的數據(數十萬對數百萬問題)時也顯著優于其他較小的模型。
數據集構造
種子集合
首先從現有的開源數據集中收集數學單詞問題樣本,即NumGLUE、AddSub、ALGES、ASDiv、DRAW、GSM8k、MATHQA、MultiArith、SingeOP、SingleEQ和SVAMP。
研究人員從Lila的訓練和驗證分裂中收集問題,以構建種子集,總共收集了36217個問題。
智能體 - ask me anything
通過從種子集中的問題創建多個單詞問題來擴展種子集,利用后續提示來創建問題。
智能體總共生成了120445個新問題,但所有生成的問題都表現出與種子詞問題相似的敘述方式,具體解決方案是使用GPT4-Trubo生成的。
智能體 - Suggester & Editor
通過解決具有挑戰性的問題進一步擴大種子集合。
為了實現這一點,研究人員引入了兩個新的智能體,即Suggester和Editor,可以協同工作以創建一個面向預定義目標的數據集:修改現有問題以增加其難度。
Suggester研究一個特定的問題,并提出了幾種在不產生實際問題的情況下提高其復雜性的方法。
Editor采用原始單詞問題和Suggester的建議,生成一個更新的、更具挑戰性的問題,迭代過程可以發生在多個回合中,每一回合都會進一步增加先前生成的問題的復雜性。
眼人員利用AutoGen框架來實現多智能體工作流。
對每個問題進行兩輪迭代,并過濾GPT4-Turbo生成的答案超過1800個字符的問題,最終收集了37157個問題。
訓練
有監督微調實驗(第一次迭代)
在Orca-Math-200K數據集上對Mistral-7B進行了微調,沒有使用packing,下面為具體的指令格式。
損失函數只基于答案token來計算。
正負信號的迭代學習
數據集構建(第二次迭代)
為了為每個問題生成額外的正樣本和負樣本,研究人員從第一次迭代的SFT調優模型中采樣四個回復。
具體來說,使用top_p=0.95和溫度=0.7,過程產生了一個數據集,其中200000個問題中的每個問題都有一個GPT4-Turbo生成的解決方案和四個學生生成的解決方法。
使用基于GPT4的精確匹配中定義的提示來評估教師(GPT4-Turbo)的答案和學生的答案之間的一致性。
對于學生生成的答案與老師的答案不匹配的所有解決方案,將其標記為負樣本。
數據集構建(第三次迭代)
為了從正反饋和負反饋中學習,研究人員評估了兩種算法的性能:直接偏好優化(DPO)和Kahneman-Tversky優化(KTO),還探索了KTO的功能,其區別在于只需要二進制「是」或「否」的回復來評估輸出的質量。
評估方法
研究人員使用精確匹配作為評估指標。
給定一個模型生成的答案,提示GPT-4來提取最終的簡短答案,并將其與金標準中的簡短答案進行匹配,即基于GPT4的精確匹配(GPT4-based-Exact-Match)。
實驗結果
研究人員測試了模型在包含1319個單詞問題的GSM8k測試集上幾個訓練過程的性能,對Mistral-7B模型進行了三次迭代的微調
在第一次迭代中,使用有監督微調來獲得M1;
第二次迭代中,對比了SFT、DPO和KTO,其中KTO訓練的模型在這一組中表現更好,獲得M2后,并使用M2生成迭代#3的數據集;
第三次迭代中,對比了DPO和KTO方法,使用M2作為模型起點。
研究人員還將這些模型與Orca-Math-200K數據集上經過三個epoch的SFT訓練進行了對比。
消融實驗
Model Generated Positives
通過將
限制為僅包含教師生成的解決方案來研究影響模型生成的正向因素(positives),換言之,研究人員移除在為迭代#2創建數據集時模型生成的所有
結果顯示,不管訓練算法如何,都會看到顯著的性能下降。
Synthetic Negatives
數據集的創建包括在M1或M2生成的所有四個回復都是positive的情況下的合成負樣本(negative creation)。
通過忽略問題qi來研究這些合成負樣本的影響,結果將第二次迭代的問題數量減少了約80k,將第三次迭代的問題數量增加了約104k
除GSM8k外的數學基準
研究人員還使用Orca Math其他幾個單詞問題數據集上進行了實驗,并且為了便于評估,最終選擇了問題答案都是單個數字的數據集。
評估指標為基于GPT4的精確匹配度量,并使用貪婪解碼生成模型回復。
沾染檢查(Contamination Check)
為了確保實驗的公正性,研究人員在文中表示:在訓練過程中,從未使用GSM8K或任何其他數據集的測試分割集,也從未將其用作合成問題生成的種子。
盡管如此,研究人員還是采用以下方法來檢測任何潛在的文本沾染(text contamination)問題:
1. 對文本進行預處理,包括將所有字符轉換為小寫、刪除標點符號、對文本進行分詞,以及刪除常見的英語停止詞,以確保數據的一致性。
2. 使用逆文檔頻率(TF-IDF)方法對文本語料庫進行矢量化,并確定測試集和訓練集之間的余弦相似性,從中為每個測試查詢選擇前k個(k=10)最相似的問題。
3. 通過計算在預設閾值0.5以上具有最高n-gram重疊的試題數量及其相應的訓練集匹配來評估文本污染的程度。
研究人員使用Jaccard相似度來計算文本對之間的n-gram重疊,并且為了進行嚴格的污染檢查,n設置為1。
需要注意的是,當使用Jaccard相似性測量時,n-gram重疊是n的非遞增函數。
4. 在執行算法時,確定表現出顯著的n-gram重疊的試題數量為8,因此根據定義的閾值,表明測試集中的文本污染可以忽略不計。
當將訓練集限制為僅包含種子問題時,表現出顯著n-gram重疊的測試問題的數量為7;并且在n≥2的情況下,表現出顯著的n-gram重疊的試題數為零。
本文轉自 新智元 ,作者:新智元
