推理模型的必經(jīng)之路-自適應推理
自適應推理模型的核心就是讓模型自己根據(jù)用戶問題的來判斷是否進行推理預測。
Arm存在三種格式:直接回答、短CoT或代碼、長CoT,同時引入Ada-GRPO解決傳統(tǒng) GRPO 中的格式崩潰問題。
除了自適應模式,Arm還支持另外兩種模式;
- 指令引導模式,用戶明確強制選擇某一種指定推理格式
- 共識引導模式,聚合直接回答、短CoT或代碼三種格式的輸出,當答案一致時,直接返回,否則認為任務較復雜,切換到Long CoT推理模式。
模型的訓練分為兩個階段,SFT和Ada-GRPO。
第一階段:SFT 推理格式理解
SFT作為冷啟動,讓模型可以用各種推理格式解決問題。
- 直接答案:直接給出答案,不進行任何推理鏈
<ANSWER>answer</ANSWER>
- 短CoT:先進行簡短的推理,然后給出答案
<COT>cot</COT><ANSWER>answer</ANSWER>
- 代碼:采用基于代碼的推理方式,格式:
<CODE>code</CODE><ANSWER>answer</ANSWER>
- 長CoT:涉及更詳細、迭代的推理過程,例如自我反思和替代方案生成等
<LONG_COT>cot</LONG_COT><ANSWER>answer</ANSWER>
模型訓練采用LlamaFactory框架,lora訓練,批次為128,學習率為 2e-4,采用余弦學習率調度器,6個epoch,10%步數(shù)預熱,訓練策略 ZeRO-3。
第二階段:Ada-GRPO訓練格式選擇
SFT 之后,模型會了使用多種推理格式進行回復,但無法根據(jù)任務自適應切換的能力,因此提出了自適應 GRPO,通過格式多樣性獎勵機制,讓模型能夠根據(jù)任務難度動態(tài)地選擇合適的推理格式。
最后,通過最大化以下目標函數(shù)來優(yōu)化模型:
結果
基座模型采用Qwen2.5-Base-3B、7B、14B模型。
SFT數(shù)據(jù)集,使用AQuA-Rat,由于僅存在直接答案和簡短CoT回答,利用GPT-4o和 DeepSeek-R1分別補充代碼和長CoT格式,過濾錯誤答案后,保留 3K 個多選題和 7.8K 個開放問題。
GPT-4o補充代碼
DeepSeek-R1補充長CoT
RL數(shù)據(jù)集,從簡單的常識推理到復雜的數(shù)學推理,包括 CommonsenseQA、GSM8K 和 MATH,總共包含 19.8K 條。
如下表所示,ARM的平均效果下降不到1%,但節(jié)省了超過30%的token。
同時,SFT只能讓模型學會格式,但沒辦法根據(jù)任務選擇合適的格式,而GRPO 確實提高了推理能力,但傾向于依賴長CoT來解決所有任務,如下圖所示。
比較自適應模式、指令引導模式、共識引導模式如下表所示,共識引導可以提高整體效果,但消耗token也更多。
驗證,自適應模式中格式的選擇不是隨機選擇,與指令引導模式上每種單獨模式比較,效果均好。
比較Ada-GRPO和GRPO,如下圖所示,在早期訓練步驟中Ada-GRPO由于選擇了次優(yōu)的推理格式,最初在準確率上落后于GRPO,但最終都收斂到相似的最終準確率。而Ada-GRPO最終將平均響應長度減少到大約GRPO的一半。
最后,想說,自適應推理應該推理模型的必經(jīng)之路,同時支持強制選擇推理模式也要支持,應用上,可以前置的就選擇強制指令,無法判斷的再讓大模型自己自適應。
本文轉載自??NLP工作站??,作者:NLP工作站
