強(qiáng)化微調(diào) ReFT:開啟大語言模型推理新范式
大家好,我是肆〇柒。因?yàn)榕c合作伙伴項(xiàng)目的需要,最近對 RL 方面的論文關(guān)注的多了一些。這兩天,我翻出一篇去年的論文來復(fù)習(xí)。這篇是來自字節(jié)跳動研究團(tuán)隊(duì)(ByteDance Research)的 ACL 2024 論文《ReFT: Reasoning with Reinforced Fine-Tuning》。這篇論文發(fā)表在《Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)》上。
在人工智能領(lǐng)域,提升大語言模型(LLM)的數(shù)學(xué)推理能力一直是研究熱點(diǎn)。然而,現(xiàn)有的監(jiān)督微調(diào)(SFT)方法結(jié)合思維鏈(CoT)注釋在泛化能力上存在明顯瓶頸。為解決這一問題,字節(jié)跳動研究團(tuán)隊(duì)提出了一種名為 ReFT(Reasoning with Reinforced Fine-Tuning)的創(chuàng)新方法,通過強(qiáng)化學(xué)習(xí)機(jī)制,使模型能夠探索多種推理路徑,從而顯著提升其在數(shù)學(xué)問題求解任務(wù)中的推理能力和泛化性能。
傳統(tǒng) SFT 方法僅依賴單一正確的推理路徑進(jìn)行訓(xùn)練,導(dǎo)致模型在面對多樣化問題時(shí)泛化能力不足。例如,在 GSM8K 數(shù)據(jù)集上,基于 SFT 的模型在某些復(fù)雜問題上表現(xiàn)不佳,準(zhǔn)確率難以突破瓶頸。這種局限性促使研究者探索新的微調(diào)范式,以充分挖掘模型的推理潛力。
下圖展示了 GSM8K 數(shù)據(jù)集中的一道示例題目及其 CoT 和答案,清晰地說明了監(jiān)督微調(diào)和強(qiáng)化微調(diào)的對比。通過這種對比,我們可以更好地理解 ReFT 如何在訓(xùn)練過程中利用多種推理路徑來提升模型的性能。
GSM8K 數(shù)據(jù)集示例題目及其 CoT 和答案
數(shù)學(xué)問題解決中,單一正確推理路徑的依賴,成為模型泛化的主要障礙。實(shí)際上,許多數(shù)學(xué)問題存在多種有效的推理路徑,模型若能學(xué)習(xí)這些路徑,將大幅提升其泛化能力。ReFT 方法被提出,它突破了傳統(tǒng)微調(diào)范式的限制,通過強(qiáng)化學(xué)習(xí)機(jī)制,使模型能夠探索多種推理路徑,從而增強(qiáng)其推理深度與準(zhǔn)確性。
ReFT 方法概述
ReFT 的核心在于兩階段訓(xùn)練框架。
首先,通過監(jiān)督微調(diào)(SFT)對模型進(jìn)行初始化,使其具備基本的數(shù)學(xué)問題求解能力。接著,利用強(qiáng)化學(xué)習(xí)(特別是 PPO 算法)對模型進(jìn)行進(jìn)一步優(yōu)化。在強(qiáng)化學(xué)習(xí)階段,模型能夠自動采樣多種推理路徑,并基于真實(shí)答案獲得獎勵信號,從而不斷調(diào)整策略,提升推理能力。相比傳統(tǒng) SFT,ReFT 預(yù)期在泛化能力上實(shí)現(xiàn)顯著提升,同時(shí)優(yōu)化模型的推理深度與準(zhǔn)確性。
下圖對比了 SFT 和 ReFT 在存在 CoT 替代方案時(shí)的表現(xiàn),直觀地展示了 ReFT 如何通過探索多種推理路徑來提升模型的性能。
SFT 和 ReFT 在 CoT 替代方案上的對比
ReFT 方法論
監(jiān)督微調(diào)(SFT)準(zhǔn)備階段
在 SFT 階段,數(shù)據(jù)集的選擇與標(biāo)注質(zhì)量至關(guān)重要。GSM8K、SVAMP、MathQA 數(shù)據(jù)集因其題目類型的多樣性和標(biāo)注的規(guī)范性,成為理想的訓(xùn)練數(shù)據(jù)源。以 GSM8K 數(shù)據(jù)集為例,其包含 8K 道數(shù)學(xué)應(yīng)用題,每道題都配有詳細(xì)的思維鏈(CoT)注釋,涵蓋從簡單算術(shù)到復(fù)雜代數(shù)的多種類型,為模型訓(xùn)練提供了豐富的樣本。
模型預(yù)訓(xùn)練基礎(chǔ)的選擇同樣關(guān)鍵。研究團(tuán)隊(duì)將 CodeLLAMA 和 Galactica 作為基礎(chǔ)模型,其預(yù)訓(xùn)練特性與數(shù)學(xué)推理任務(wù)高度契合。CodeLLAMA 在代碼生成任務(wù)上的優(yōu)勢,使其能夠更好地理解數(shù)學(xué)問題中的邏輯結(jié)構(gòu);而 Galactica 在科學(xué)文獻(xiàn)處理上的專長,則有助于模型對數(shù)學(xué)問題中專業(yè)術(shù)語的理解。SFT 初始化策略,如學(xué)習(xí)率的設(shè)置、預(yù)訓(xùn)練權(quán)重的加載方式等,對后續(xù)強(qiáng)化學(xué)習(xí)階段的學(xué)習(xí)效果有著深遠(yuǎn)影響。
SFT 的訓(xùn)練目標(biāo)函數(shù)基于交叉熵?fù)p失,通過最小化模型預(yù)測與真實(shí) CoT 標(biāo)注之間的差異,使模型逐步掌握數(shù)學(xué)問題的基本解題思路。訓(xùn)練過程中的收斂性判斷標(biāo)準(zhǔn),如連續(xù)多個(gè) epoch 驗(yàn)證損失不再下降,則表明模型在當(dāng)前數(shù)據(jù)集上已達(dá)到較好的擬合效果,可進(jìn)入強(qiáng)化學(xué)習(xí)階段。
ReFT 強(qiáng)化學(xué)習(xí)階段
ReFT 強(qiáng)化學(xué)習(xí)階段采用 PPO(Proximal Policy Optimization)算法,這是一種在策略梯度方法基礎(chǔ)上改進(jìn)的強(qiáng)化學(xué)習(xí)算法,具有穩(wěn)定性和高效性優(yōu)勢。PPO 算法通過限制策略更新的幅度,避免了策略梯度方法中常見的訓(xùn)練不穩(wěn)定問題。在 ReFT 的應(yīng)用場景下,PPO 算法的參數(shù)調(diào)整需根據(jù)數(shù)學(xué)問題的特點(diǎn)進(jìn)行優(yōu)化,例如學(xué)習(xí)率的設(shè)置、折扣因子 γ 的選擇等。
PPO 算法的具體運(yùn)算過程如下:
1. 策略網(wǎng)絡(luò)構(gòu)建:策略網(wǎng)絡(luò)采用多層感知機(jī)(MLP)結(jié)構(gòu),輸入為問題狀態(tài),輸出為動作概率分布。例如,對于一個(gè)數(shù)學(xué)問題求解任務(wù),策略網(wǎng)絡(luò)的輸入可以是問題的文本編碼,輸出則是下一步推理動作的概率分布。
2. 價(jià)值函數(shù)估計(jì):價(jià)值函數(shù)用于估計(jì)當(dāng)前狀態(tài)下的期望累計(jì)獎勵。通過訓(xùn)練一個(gè)價(jià)值網(wǎng)絡(luò),使用均方誤差損失函數(shù)來擬合真實(shí)價(jià)值函數(shù)。價(jià)值網(wǎng)絡(luò)的輸入與策略網(wǎng)絡(luò)相同,輸出為一個(gè)標(biāo)量值,表示當(dāng)前狀態(tài)的價(jià)值。
3. 優(yōu)勢函數(shù)計(jì)算:優(yōu)勢函數(shù)衡量在當(dāng)前狀態(tài)下采取特定動作相對于平均策略的優(yōu)劣。計(jì)算公式為:
4. 策略更新:根據(jù)采樣的軌跡計(jì)算優(yōu)勢函數(shù)估計(jì)值,使用 PPO 的裁剪目標(biāo)函數(shù)更新策略網(wǎng)絡(luò)參數(shù)。裁剪目標(biāo)函數(shù)為:
從單一問題中采樣多種推理路徑是 ReFT 的關(guān)鍵創(chuàng)新之一。基于策略梯度的路徑探索機(jī)制,模型能夠在給定問題時(shí)生成多種可能的推理路徑。通過多樣性采樣技術(shù),如溫度調(diào)節(jié)(temperature scaling)、核采樣(top-k sampling)等,模型能夠生成具有多樣性的路徑集合。隨后,利用篩選機(jī)制,如基于答案正確性的過濾、基于路徑相似度的去重等,保留有效的推理路徑,從而豐富模型的學(xué)習(xí)樣本。
獎勵信號的設(shè)計(jì)直接關(guān)系到模型的學(xué)習(xí)效果。ReFT 的獎勵函數(shù)以真實(shí)答案為核心,當(dāng)模型生成的推理路徑得出正確答案時(shí),給予正向獎勵;否則,給予懲罰。部分獎勵策略在稀疏反饋環(huán)境中發(fā)揮著重要作用,例如在數(shù)學(xué)問題的中間步驟給予一定獎勵,引導(dǎo)模型逐步接近正確答案,從而緩解了強(qiáng)化學(xué)習(xí)中常見的稀疏獎勵問題。
下圖展示了 MathQAMCQ 數(shù)據(jù)集中的一個(gè)示例預(yù)測,展示了獎勵欺騙現(xiàn)象。當(dāng)模型生成錯誤的推理路徑卻得出正確答案時(shí),會獲得不當(dāng)獎勵,誤導(dǎo)模型的學(xué)習(xí)方向。這種現(xiàn)象在多選題場景下尤為突出,嚴(yán)重時(shí)可能導(dǎo)致模型性能下降。ReFT 通過合理設(shè)計(jì)獎勵函數(shù)和采樣策略,在一定程度上緩解了獎勵欺騙問題,確保了訓(xùn)練過程的可靠性。
MathQAMCQ 數(shù)據(jù)集示例預(yù)測,揭示獎勵欺騙現(xiàn)象
ReFT 關(guān)鍵機(jī)制深度解析
線上強(qiáng)化學(xué)習(xí)與自監(jiān)督學(xué)習(xí)在 ReFT 中相輔相成。線上強(qiáng)化學(xué)習(xí)使模型能夠?qū)崟r(shí)根據(jù)環(huán)境反饋調(diào)整策略,而自監(jiān)督學(xué)習(xí)則利用模型自身生成的數(shù)據(jù)進(jìn)行進(jìn)一步學(xué)習(xí),兩種范式的協(xié)同作用顯著提升了模型的泛化能力。例如,在處理復(fù)雜的代數(shù)問題時(shí),模型通過線上強(qiáng)化學(xué)習(xí)不斷嘗試不同的解題思路,同時(shí)借助自監(jiān)督學(xué)習(xí)對生成的推理路徑進(jìn)行自我評估與優(yōu)化,從而逐步掌握問題的解題規(guī)律。
部分獎勵策略與 KL 散度約束的平衡機(jī)制是 ReFT 的另一關(guān)鍵。部分獎勵在不同推理階段的合理應(yīng)用,如在問題初期給予較高的探索獎勵,隨著推理深入逐步增加開發(fā)獎勵,能夠引導(dǎo)模型在探索與利用之間取得平衡。KL 散度約束則通過限制新舊策略之間的差異,防止模型在強(qiáng)化學(xué)習(xí)過程中偏離初始策略過遠(yuǎn),從而保證了訓(xùn)練的穩(wěn)定性。這種平衡機(jī)制的動態(tài)調(diào)整,使模型能夠在復(fù)雜多變的數(shù)學(xué)問題中保持穩(wěn)定的性能提升。
ReFT 支持自然語言 CoT 與程序基 CoT 的雙重處理框架。自然語言 CoT 以自然語言形式描述推理過程,易于人類理解和解釋;而程序基 CoT 則以編程語言形式表達(dá),具有更高的精確性和可執(zhí)行性。ReFT 的融合處理框架能夠充分利用兩種 CoT 形式的優(yōu)點(diǎn),增強(qiáng)模型在不同場景下的適用性與魯棒性。例如,在處理涉及邏輯判斷與循環(huán)操作的數(shù)學(xué)問題時(shí),程序基 CoT 能夠提供更清晰的執(zhí)行步驟,而自然語言 CoT 則有助于模型理解問題背景與上下文信息。
與離線自訓(xùn)練和在線自訓(xùn)練方法相比,ReFT 具有顯著優(yōu)勢。離線自訓(xùn)練受限于初始采樣數(shù)據(jù)的質(zhì)量與多樣性,難以動態(tài)調(diào)整訓(xùn)練策略;在線自訓(xùn)練則存在反饋延遲問題,影響模型的實(shí)時(shí)學(xué)習(xí)效果。ReFT 的即時(shí)反饋與動態(tài)調(diào)整機(jī)制使其能夠在訓(xùn)練過程中快速適應(yīng)問題的復(fù)雜性,從而實(shí)現(xiàn)更高效的性能提升。
SFT 方法在數(shù)學(xué)問題求解中的局限性主要體現(xiàn)在其對單一正確推理路徑的依賴。例如,當(dāng)面對具有多種解題方法的數(shù)學(xué)問題時(shí),SFT 模型往往只能學(xué)習(xí)到其中一種方法,導(dǎo)致其在面對其他解題思路時(shí)泛化能力不足。ReFT 通過強(qiáng)化學(xué)習(xí)機(jī)制,使模型能夠探索多種推理路徑。例如,在 GSM8K 數(shù)據(jù)集上,ReFT 能夠通過采樣不同的推理路徑,逐步學(xué)習(xí)到多種解題方法,從而克服 SFT 方法的局限性,提升模型的泛化能力和推理深度。
實(shí)驗(yàn)設(shè)計(jì)與結(jié)果評估
實(shí)驗(yàn)環(huán)境與配置
實(shí)驗(yàn)基于 GSM8K、SVAMP 和 MathQA 三大數(shù)據(jù)集展開,這些數(shù)據(jù)集在數(shù)學(xué)問題求解研究中具有代表性,涵蓋了從基礎(chǔ)算術(shù)到高級代數(shù)的廣泛問題類型。例如,SVAMP 數(shù)據(jù)集包含 3,000 多道經(jīng)過嚴(yán)格篩選的數(shù)學(xué)題,題目難度適中且具有良好的代表性。下表提供了訓(xùn)練集和測試集的統(tǒng)計(jì)信息,展示了數(shù)據(jù)集的規(guī)模和特性。
訓(xùn)練集和測試集的統(tǒng)計(jì)信息
基礎(chǔ)模型選擇 CodeLLAMA 和 Galactica,主要考慮其架構(gòu)特點(diǎn)與數(shù)學(xué)推理任務(wù)的適配性。CodeLLAMA 的 decoder-only 架構(gòu)使其在生成任務(wù)上具有高效性,而 Galactica 的 large context window 特性能夠處理較長的數(shù)學(xué)問題描述。訓(xùn)練硬件環(huán)境采用 8 塊 A100-80GB GPU,配合 DeepSpeed Zero stage 2 和 HuggingFace Accelerate,確保了訓(xùn)練過程的高效性與穩(wěn)定性。
在實(shí)驗(yàn)中,ReFT 方法與多種基線方法進(jìn)行了對比,包括 SFT、離線自訓(xùn)練和在線自訓(xùn)練。SFT 作為傳統(tǒng)方法,直接利用標(biāo)注數(shù)據(jù)進(jìn)行監(jiān)督訓(xùn)練;離線自訓(xùn)練通過初始模型生成額外樣本進(jìn)行訓(xùn)練;在線自訓(xùn)練則在訓(xùn)練過程中動態(tài)生成樣本。為確保公平比較,所有基線方法均采用相同的超參數(shù)調(diào)整策略,如學(xué)習(xí)率、批次大小等,并通過交叉驗(yàn)證評估性能穩(wěn)定性。
實(shí)驗(yàn)結(jié)果呈現(xiàn)與分析
下表展示了 ReFT 和基線方法在所有數(shù)據(jù)集上的價(jià)值準(zhǔn)確率。在 GSM8K 數(shù)據(jù)集上,ReFT 的自然語言 CoT 準(zhǔn)確率達(dá)到 75.28%,程序基 CoT 準(zhǔn)確率更是高達(dá) 81.2%,相比 SFT 方法分別提升了近 12 個(gè)百分點(diǎn)和 17 個(gè)百分點(diǎn)。在 SVAMP 數(shù)據(jù)集上,ReFT 的準(zhǔn)確率提升了約 10 個(gè)百分點(diǎn)。這些結(jié)果表明 ReFT 在不同數(shù)據(jù)集上均能顯著超越基線方法,展現(xiàn)出卓越的推理性能。
ReFT 和基線方法在所有數(shù)據(jù)集上的價(jià)值準(zhǔn)確率
下表針對 MathQAnumeric 基準(zhǔn)測試,進(jìn)一步驗(yàn)證了 ReFT 的魯棒性。ReFT 在該變種數(shù)據(jù)集上的準(zhǔn)確率達(dá)到 78.0%,相比 SFT 提升了近 15 個(gè)百分點(diǎn)。這表明 ReFT 在處理數(shù)值型答案的數(shù)學(xué)問題時(shí),能夠有效避免獎勵欺騙問題,保持穩(wěn)定的性能表現(xiàn)。
ReFT 和基線方法在 MathQAnumeric 基準(zhǔn)測試上的價(jià)值準(zhǔn)確率
下表則凸顯了多數(shù)投票與重排序技術(shù)對 ReFT 性能的顯著增益效果。結(jié)合多數(shù)投票策略后,ReFT 在 GSM8K 數(shù)據(jù)集上的準(zhǔn)確率提升了 8.6 個(gè)百分點(diǎn);而在重排序技術(shù)的助力下,準(zhǔn)確率提升了超過 3 個(gè)百分點(diǎn)。這些結(jié)果充分證明了 ReFT 與這些技術(shù)的兼容性,能夠通過集成方法進(jìn)一步提升模型的性能。
多數(shù)投票和重排序技術(shù)對 SFT 和 ReFT 在 GSM8K 數(shù)據(jù)集上的解題準(zhǔn)確率影響
下圖展示了 ReFT 在 GSM8K P-CoT 數(shù)據(jù)集上的訓(xùn)練獎勵、評估準(zhǔn)確率和 KL 散度隨訓(xùn)練周期的變化情況。從圖中可以看出,隨著訓(xùn)練的進(jìn)行,ReFT 的評估準(zhǔn)確率穩(wěn)步提升,同時(shí) KL 散度逐漸趨于穩(wěn)定,反映了 ReFT 在強(qiáng)化學(xué)習(xí)階段的訓(xùn)練動態(tài)過程和穩(wěn)定性。
ReFT 在 GSM8K P-CoT 數(shù)據(jù)集上的訓(xùn)練獎勵、評估準(zhǔn)確率和 KL 散度變化情況
下表的消融研究結(jié)果進(jìn)一步量化了 ReFT 各個(gè)關(guān)鍵組件的貢獻(xiàn)。例如,當(dāng)移除部分獎勵策略時(shí),ReFT 在 GSM8K P-CoT 任務(wù)上的準(zhǔn)確率從 81.2% 下降至 80.2%;而將 KL 系數(shù) β 設(shè)置為 0 時(shí),模型性能出現(xiàn)嚴(yán)重退化,準(zhǔn)確率幾乎降為 0。這些結(jié)果凸顯了部分獎勵策略和 KL 散度約束在維持 ReFT 穩(wěn)定性和性能方面的重要作用。
消融研究結(jié)果
下圖比較了 SFT 和 ReFT 在不同預(yù)熱 epoch 數(shù)下的準(zhǔn)確率。結(jié)果顯示,ReFT 在經(jīng)過適當(dāng)?shù)念A(yù)熱步驟后,性能顯著優(yōu)于 SFT,尤其是在預(yù)熱 epoch 為 3 和 5 時(shí),ReFT 的準(zhǔn)確率提升最為明顯。
不同預(yù)熱 epoch 數(shù)下 SFT 和 ReFT 的準(zhǔn)確率對比
下圖展示了 SFT 和 ReFT 模型在 GSM8K 數(shù)據(jù)集上同一問題的不同訓(xùn)練周期的 P-CoT 響應(yīng)。綠色框架表示正確的響應(yīng),紅色框架表示錯誤的響應(yīng)。從圖中可以看出,ReFT 在訓(xùn)練過程中逐漸收斂到正確的解題路徑,而 SFT 則在多個(gè)訓(xùn)練周期中表現(xiàn)不穩(wěn)定。
GSM8K 數(shù)據(jù)集上同一問題在不同訓(xùn)練周期的 P-CoT 響應(yīng)對比
結(jié)果分析與洞察
ReFT 在不同數(shù)據(jù)集上的性能提升呈現(xiàn)出一些共性規(guī)律。例如,在涉及多步推理的復(fù)雜問題上,ReFT 的性能提升更為顯著,這歸因于其能夠探索多種推理路徑,從而更好地應(yīng)對問題的復(fù)雜性。同時(shí),數(shù)據(jù)集的特性也對性能提升產(chǎn)生影響。在 GSM8K 數(shù)據(jù)集上,由于問題類型的多樣性,ReFT 能夠充分利用其路徑探索能力,實(shí)現(xiàn)顯著的性能提升。而在 SVAMP 數(shù)據(jù)集上,由于部分問題存在固定的解題模板,ReFT 的提升幅度相對較小,但仍優(yōu)于基線方法。
小模型實(shí)驗(yàn)進(jìn)一步驗(yàn)證了 ReFT 的泛化能力。即使在參數(shù)量較少的模型上,ReFT 仍能取得優(yōu)于 SFT 的結(jié)果。例如,在 Galactica-125M 模型上,ReFT 在 GSM8K 數(shù)據(jù)集上的準(zhǔn)確率相比 SFT 提升了近 6 個(gè)百分點(diǎn)。這表明 ReFT 方法具有良好的普適性,能夠適應(yīng)不同規(guī)模的模型。
總體而言,實(shí)驗(yàn)結(jié)果充分證明了 ReFT 方法在提升大語言模型數(shù)學(xué)推理能力方面的顯著優(yōu)勢,為未來推理任務(wù)的研究和實(shí)踐提供了新的方向和思路。
實(shí)踐指南與代碼實(shí)現(xiàn)
環(huán)境搭建步驟
搭建 ReFT 的運(yùn)行環(huán)境,首先需安裝依賴庫,包括 transformers、torch、accelerate 等。各庫的版本需滿足兼容性要求,例如 transformers 版本應(yīng)與基礎(chǔ)模型的實(shí)現(xiàn)相匹配。以下是具體的安裝命令:
pip install transformers==4.28.0 torch==1.13.1 accelerate==0.16.0
數(shù)據(jù)預(yù)處理流程涉及將原始數(shù)據(jù)集轉(zhuǎn)換為模型可接受的格式,如將 GSM8K 數(shù)據(jù)集中的問題、CoT 和答案整理為 JSON 格式。數(shù)據(jù)格式規(guī)范對模型訓(xùn)練至關(guān)重要,不正確的格式可能導(dǎo)致訓(xùn)練過程中的錯誤。
SFT 實(shí)現(xiàn)詳解
train_sft_model.py
腳本是 SFT 的核心實(shí)現(xiàn)。其關(guān)鍵參數(shù)包括學(xué)習(xí)率、批次大小、訓(xùn)練 epoch 數(shù)等。例如,學(xué)習(xí)率設(shè)置為 1e-5,批次大小為 48,訓(xùn)練 epoch 數(shù)為 40。這些參數(shù)的選擇基于實(shí)驗(yàn)經(jīng)驗(yàn)和數(shù)據(jù)集特性,對 SFT 的訓(xùn)練效果有著直接的影響。
在訓(xùn)練過程中,需監(jiān)控?fù)p失變化和驗(yàn)證集準(zhǔn)確率等關(guān)鍵指標(biāo)。可以通過 TensorBoard 進(jìn)行可視化,具體命令如下:
tensorboard --logdir=./logs
當(dāng)驗(yàn)證集準(zhǔn)確率 plateau 時(shí),可以嘗試調(diào)整學(xué)習(xí)率或增加正則化。例如,將學(xué)習(xí)率降低一個(gè)數(shù)量級:
optimizer = AdamW(model.parameters(), lr=1e-6)
ReFT 代碼實(shí)戰(zhàn)
train_rl_reft.py
腳本實(shí)現(xiàn)了 ReFT 的強(qiáng)化學(xué)習(xí)流程。以下是 PPO 算法的關(guān)鍵代碼片段:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
classPPO:
def__init__(self, model, lr, gamma, epsilon, device):
self.model = model
self.optimizer = optim.Adam(model.parameters(), lr=lr)
self.gamma = gamma
self.epsilon = epsilon
self.device = device
defcompute_advantages(self, rewards, values):
advantages = []
gae = 0
for t inreversed(range(len(rewards))):
delta = rewards[t] + self.gamma * values[t+1] - values[t]
gae = delta + self.gamma * gae
advantages.insert(0, gae)
return advantages
defupdate(self, states, actions, rewards, log_probs_old):
states = torch.tensor(states, dtype=torch.float32).to(self.device)
actions = torch.tensor(actions, dtype=torch.int64).to(self.device)
rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
log_probs_old = torch.tensor(log_probs_old, dtype=torch.float32).to(self.device)
# 計(jì)算價(jià)值函數(shù)
values = self.model.value(states)
# 計(jì)算優(yōu)勢函數(shù)
advantages = self.compute_advantages(rewards, values)
advantages = torch.tensor(advantages, dtype=torch.float32).to(self.device)
# 計(jì)算新策略的概率分布
logits = self.model.policy(states)
dist = Categorical(logits=logits)
log_probs_new = dist.log_prob(actions)
# 計(jì)算 PPO 裁剪目標(biāo)函數(shù)
ratio = torch.exp(log_probs_new - log_probs_old)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 計(jì)算價(jià)值函數(shù)損失
value_loss = nn.MSELoss()(values, rewards)
# 更新模型
self.optimizer.zero_grad()
policy_loss.backward()
value_loss.backward()
self.optimizer.step()
在強(qiáng)化學(xué)習(xí)訓(xùn)練過程中,調(diào)試技巧至關(guān)重要。例如,通過打印中間策略分布、獎勵值等信息,診斷采樣多樣性不足、獎勵稀疏等問題,并據(jù)此調(diào)整采樣溫度、獎勵函數(shù)參數(shù)等。常用的調(diào)試工具有 TensorBoard(用于可視化訓(xùn)練指標(biāo))、PyTorch 的斷點(diǎn)調(diào)試功能等。
采樣與評估實(shí)踐
sampling.py
提供了多種采樣策略配置,如溫度采樣、核采樣、束搜索等。以下是一個(gè)溫度采樣的實(shí)現(xiàn)示例:
def temperature_sampling(logits, temperature):
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
return probs
不同采樣策略適用于不同場景,例如,在探索階段可采用較高的溫度值以增加采樣多樣性;而在開發(fā)階段則可降低溫度值以聚焦于高概率路徑。采樣參數(shù)的調(diào)整對結(jié)果多樣性有顯著影響,較高的溫度值會產(chǎn)生更多樣化的路徑,但也可能引入更多噪聲。
重排序模型的訓(xùn)練基于生成的多個(gè) CoT 樣本,通過訓(xùn)練二分類器判斷樣本的正確性,從而實(shí)現(xiàn)對 CoT 的重排序。模型集成策略,如將多個(gè)重排序模型的預(yù)測結(jié)果進(jìn)行加權(quán)平均,能夠進(jìn)一步提升最終性能。例如,在 GSM8K 數(shù)據(jù)集上,結(jié)合重排序模型后,ReFT 的準(zhǔn)確率提升了超過 3 個(gè)百分點(diǎn)。
性能優(yōu)化
為提升訓(xùn)練效率,可采用多種工程實(shí)踐。例如,利用混合精度訓(xùn)練(mixed precision training)減少內(nèi)存占用并加速計(jì)算;采用梯度累積技術(shù),在有限 GPU 內(nèi)存下模擬大批次訓(xùn)練效果;優(yōu)化數(shù)據(jù)加載流程,減少 I/O 瓶頸等。以下是一個(gè)混合精度訓(xùn)練的實(shí)現(xiàn)示例:
scaler = torch.cuda.amp.GradScaler()
for epoch in range(num_epochs):
for batch in dataloader:
with torch.cuda.amp.autocast():
outputs = model(batch)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
獎勵欺騙問題的緩解方案包括設(shè)計(jì)更精細(xì)的獎勵函數(shù),如根據(jù)中間步驟的正確性給予部分獎勵;引入專家示范數(shù)據(jù),在訓(xùn)練初期引導(dǎo)模型學(xué)習(xí)正確的推理路徑;實(shí)時(shí)監(jiān)控訓(xùn)練過程中的獎勵分布,及時(shí)發(fā)現(xiàn)并糾正異常的獎勵模式。
總結(jié)
ReFT 方法在數(shù)學(xué)推理任務(wù)上取得了顯著的性能提升。在 GSM8K 數(shù)據(jù)集上,相比 SFT 方法,ReFT 的自然語言 CoT 準(zhǔn)確率提升了 12 個(gè)百分點(diǎn),程序基 CoT 準(zhǔn)確率提升了 17 個(gè)百分點(diǎn);在 SVAMP 數(shù)據(jù)集上,準(zhǔn)確率提升了 10 個(gè)百分點(diǎn)。這些量化評估結(jié)果充分證明了 ReFT 對模型推理能力邊界的擴(kuò)展作用,使其能夠應(yīng)對更復(fù)雜的數(shù)學(xué)問題。
ReFT 對 LLM 微調(diào)范式的創(chuàng)新拓展價(jià)值不容忽視。它為現(xiàn)有微調(diào)技術(shù)體系引入了強(qiáng)化學(xué)習(xí)機(jī)制,豐富了模型的學(xué)習(xí)方式。這一創(chuàng)新不僅提升了模型在數(shù)學(xué)推理任務(wù)上的性能,還為未來微調(diào)方法的研究提供了新的思路與方向,推動了微調(diào)技術(shù)的進(jìn)一步發(fā)展。
局限性分析
盡管 ReFT 取得了顯著成果,但在訓(xùn)練效率方面仍存在瓶頸。強(qiáng)化學(xué)習(xí)階段的訓(xùn)練收斂速度較慢,尤其是在處理大規(guī)模數(shù)據(jù)集時(shí),訓(xùn)練時(shí)間成倍增長。這主要?dú)w因于強(qiáng)化學(xué)習(xí)的試錯特性,模型需通過大量采樣與反饋逐步優(yōu)化策略。潛在的解決方案包括采用更高效的強(qiáng)化學(xué)習(xí)算法,如基于模型的強(qiáng)化學(xué)習(xí)(Model-Based RL),通過學(xué)習(xí)環(huán)境模型減少采樣需求;優(yōu)化采樣策略,提高采樣效率,如采用優(yōu)先經(jīng)驗(yàn)回放(Prioritized Experience Replay)技術(shù),聚焦于信息量大的樣本。
獎勵欺騙問題是 ReFT 面臨的另一挑戰(zhàn)。其深層成因在于獎勵信號的不完全性,當(dāng)模型生成的推理路徑得出正確答案但過程錯誤時(shí),仍可能獲得獎勵,誤導(dǎo)模型學(xué)習(xí)方向。應(yīng)對思路包括設(shè)計(jì)更全面的獎勵函數(shù),綜合考慮路徑的中間結(jié)果、邏輯合理性等多維度信息;引入輔助監(jiān)督信號,如基于中間步驟正確性的獎勵,引導(dǎo)模型學(xué)習(xí)正確的推理過程;在訓(xùn)練過程中增加人類反饋環(huán)節(jié),及時(shí)糾正模型的錯誤推理模式。
未來方向
我們在未來的探索中,可以探索將離線強(qiáng)化學(xué)習(xí)技術(shù)與 ReFT 方法進(jìn)行整合。離線強(qiáng)化學(xué)習(xí)技術(shù)利用預(yù)先收集的數(shù)據(jù)進(jìn)行訓(xùn)練,避免了在線強(qiáng)化學(xué)習(xí)中與環(huán)境交互的高成本和高風(fēng)險(xiǎn)。然而,離線強(qiáng)化學(xué)習(xí)也面臨著數(shù)據(jù)分布偏移、策略退化等挑戰(zhàn)。通過將離線強(qiáng)化學(xué)習(xí)的優(yōu)勢與 ReFT 的在線探索能力相結(jié)合,有望開發(fā)出更加高效、穩(wěn)定的強(qiáng)化學(xué)習(xí)方法。
此外,開發(fā)過程導(dǎo)向的獎勵模型也是一個(gè)重要的研究方向。與傳統(tǒng)的基于最終結(jié)果的獎勵模型不同,過程導(dǎo)向的獎勵模型更加關(guān)注推理過程的質(zhì)量和合理性。例如,可以通過對推理路徑的中間步驟進(jìn)行評估,給予相應(yīng)的獎勵信號,從而引導(dǎo)模型生成更高質(zhì)量的推理路徑。這需要設(shè)計(jì)更加精細(xì)的獎勵模型結(jié)構(gòu)和訓(xùn)練方法,同時(shí)也對數(shù)據(jù)標(biāo)注和特征提取提出了更高的要求。
探索 ReFT 在其他推理任務(wù)領(lǐng)域的遷移應(yīng)用前景也具有重要意義。例如,在邏輯推理、文本蘊(yùn)含、知識問答等領(lǐng)域,ReFT 的強(qiáng)化微調(diào)思路和方法可能同樣能夠發(fā)揮重要作用。通過針對這些任務(wù)的特點(diǎn)和需求,對 ReFT 方法進(jìn)行適當(dāng)?shù)母脑旌蛢?yōu)化,有望進(jìn)一步提升模型在這些領(lǐng)域的推理能力和性能。
記得當(dāng)時(shí)我讀完這篇論文,我深感 ReFT 方法為大語言模型的推理能力提升開辟了全新的路徑。通過強(qiáng)化學(xué)習(xí)機(jī)制,ReFT 使模型能夠擺脫對單一正確推理路徑的依賴,大膽探索多樣化的解題思路。這種創(chuàng)新的微調(diào)范式不僅顯著提升了模型在數(shù)學(xué)問題求解任務(wù)上的性能,還為未來微調(diào)技術(shù)的發(fā)展提供了寶貴的借鑒,要知道高效微調(diào)對 Agent 有多么重要!在去年年底,OpenAI 就推出了相似的 RFT 方法,并于今年 5 月初,RFT 初步落地。感慨,AI 行業(yè)太快了!
總體而言,ReFT 不僅是一項(xiàng)技術(shù)進(jìn)步,更是對大語言模型推理能力邊界的一次勇敢探索。它讓我看到了強(qiáng)化學(xué)習(xí)在提升模型智能水平方面的巨大潛力,也讓我對 AI 的未來發(fā)展充滿期待。