大模型微調非得依賴人類數據嗎?DeepMind:用帶反饋的自訓練更好
如你我所見,大語言模型(LLM)正在改變深度學習的格局,在生成人類質量的文本和解決各種語言任務方面展現出了卓越的能力。雖然業界通過對人類收集的數據進行監督微調進一步提升了在具體任務上的性能,但獲取高質量人類數據卻面臨著重大瓶頸。這對于要解決復雜問題的任務來說尤為明顯,需要大量資源和專業知識。
怎么解決呢?模型生成得合成數據是一種有潛力的替代方案,只要能保證數據的質量,就能實現可擴展性和成本效益。
雖然 LLM 能夠自我評估生成的數據,但在本文中,谷歌 DeepMind 探索了一種更簡單的設置,將外部標量反饋信號用作每個生成樣本的質量指標。
論文地址:https://arxiv.org/pdf/2312.06585.pdf
為了研究在模型生成數據上的訓練,研究者考慮了一種簡單但強大的語言模型自訓練方法,僅需要兩項功能,一是基于模型生成樣本,二是利用評分機制對這些樣本進行評估。
為了確保清晰度和一致性,研究者采用了一種強化自訓練方法 ReST^????,并證明該方法可以將期望最大化(expectation-maximization,EM)用于強化學習。具體來講,ReST^????在期望和最大化步驟之間交替進行。
- 生成(E-step):語言模型為每個輸入上下文生成多個輸出樣本,然后使用二元獎勵過濾這些樣本以收集訓練數據集。
- 改進(M-step):原始語言模型在來自前一個 E-step 的訓練數據集上進行監督微調,然后在下一個 E-step 中使用。
研究者證實,ReST^????及變體在增強各個領域的語言模型方面取得了成功,包括機器翻譯、語義分析、偏好對齊和基礎推理。
此外,以往工作主要將 ReST^????用于相對較小的模型(最高 70 億參數),對于較大模型的可擴展性受限。因此,本文旨在探究模型生成的合成數據與人類生成的數據在以下兩個具有挑戰性但研究較少領域的有效性和可擴展性,這兩個領域分別是競爭水平數學解題(MATH)和代碼生成(APPS)。
實證結果表明,當將 ReST^????用于不同規模的 PaLM 2 模型時,在數學推理和代碼生成任務中實現了顯著的能力改進。與在人類編寫數據上訓練的模型相比,在模型生成的合成數據上微調的模型取得了更大的性能增益。有趣的是,超過了一定數量的 ReST^???? 迭代后,性能會降低,這表明了在少量訓練問題上可能會出現過擬合。
此外,使用 ReST^????微調的模型提升了 pass@k 指標和多數投票性能。這些微調后的模型在相關但 held-out 的基準上也表現出了性能增強,包括數學題(GSM8K 和 Hungarian HS finals)、編碼(HumanEval)和 Big-Bench Hard 任務。
總之,本文研究結果表明,具有反饋的自訓練是減少對人類數據依賴的一種有潛力的方法。
用于強化自訓練的期望最大值(EM)
首先,該研究基于 Dayan 和 Hinton 之前的研究,用語言模型描述了基于 EM 的強化學習框架。具體而言,他們先是定義了一個二進制最優變量 O,使得??(??= 1|??,??)∝??(??(??,??));然后對非遞減函數 ?? : ? → ?+ ,實現最大化觀察??= 1(獲得高獎勵),得到如下公式:
然而,求解上式中的序列 ?? 的和很棘手。因而本文考慮相對于參數 ?? 和變分分布 ??( ??|??) 最大化其 ELBO ??( ????, ??),而不是最大化 log ??(?? = 1; ??)。具體來說:
公式(2)中的 EM 算法在 E-step(Expectation) 和 M-step(Maximization)之間交替進行。
ReST^????:受 EM 框架的啟發,接下來論文討論了 Gulcehre 等人提出的 ReST 方法的簡化版本。為了清楚起見,本文將這種方法稱為 ReST^????,它將 RL pipeline 中的數據收集 (E-step) 和策略優化 (M-step) 進行解耦。如算法 1 所示:
生成(E-step):在此步驟中,該研究通過從當前策略 ???? 中采樣輸出序列來生成數據集
。在這里,輸入是從原始數據集
中重新采樣的。然后使用二元獎勵函數 ??(??, ??) 對
中的輸出序列進行評分。
改進(M-step):在第 ??步迭代中,該研究使用 E-step 中的新數據集來微調策略 ????。不同于 Gulcehre 的研究,他們微調基本預訓練語言模型,以最大限度地減少特定于任務的過度擬合并最大限度地減少與基本模型的偏差。為了進行微調,該研究最小化獎勵加權負對數似然損失
。一旦策略得到改進,就可以再次創建質量更好樣本的新數據集。
實驗和分析
本文進行實驗的主要目標是回答以下問題:
- 與人類生成的數據進行微調相比,ReST^????的效果如何?
- 需要多少次迭代才能獲得最佳性能?ReST^????多長時間會導致訓練集過度擬合?
- ReST^????如何影響 pass@k 和多數投票表現?
- 如果用戶在特定任務上使用模型生成的數據進行微調,是否會遷移到其他任務上?在廣泛的任務中評估本文的微調模型時,與基本模型相比,性能是否會下降?
- 大約需要多少輸入數據才能從 ReST^???? 獲得大部分性能提升?ReST^????的一次迭代是否足夠?
該研究使用 PaLM 2 模型和 Google Cloud 上的公共 API 進行實驗,包括 PaLM 2-S (Bison)、PaLM 2-S* (Codey) 和 PaLM 2-L (Unicorn)。訓練數據集采用 MATH 數據集和 APPS 數據集。
圖 2 和圖 3 分別顯示了 ReST^????在 MATH 和 APPS 數據集上訓練的性能??梢缘贸?MATH 受益于 ReST^???? 的多次迭代,無論是在 MATH 測試集上的性能還是遷移到 GSM8K 方面。另一方面可以看到 APPS 的大部分收益來自第一次迭代,而執行更多次迭代會導致 APPS 和 HumanEval 的性能下降。
訓練和測試性能的差距。圖 4 顯示,雖然訓練集性能隨著 ReST^????迭代次數線性增加,但測試集性能卻沒有。對于 MATH,第一次迭代后測試性能改進很小,而對于 APPS,在第二次迭代中觀察到性能回歸。該研究猜測性能的回歸可能是由于過度擬合造成的。由于 APPS 數據集的大小約為 MATH 數據集的三分之一,因此它更容易受到此問題的影響。
圖 5 顯示了 Palm-2-L 模型在 pass@K 指標上的性能。結果顯示,微調后獲得的 ReST^???? 模型對于所有 K 值都更強,其中性能差距通常在 K=1 時最大。