算數能力接近滿分!新加坡國立大學發布Goat,僅用70億參數秒殺GPT-4,起步支持16位數乘除法
大規模語言模型雖然在各大自然語言處理任務上都展現了優越的性能,不過算術類題目仍然是一大難關,即便是當下最強的GPT-4也很難處理基礎運算的問題。
最近,來自新加坡國立大學的研究人員提出了一個專供算術的模型山羊Goat,在LLaMA模型基礎上微調后,實現了顯著優于GPT-4的算術能力。
論文鏈接:https://arxiv.org/pdf/2305.14201.pdf
通過對合成的算術數據集進行微調,Goat在BIG-bench算術子任務上實現了最先進的性能,
Goat僅通過監督微調就可以在大數加減運算上實現近乎完美的準確率,超越了之前所有的預訓練語言模型,如Bloom、OPT、GPT-NeoX等,其中零樣本的Goat-7B所達到的精度甚至超過了少樣本學習后的PaLM-540
研究人員將Goat的卓越性能歸功于LLaMA對數字的一致性分詞技術。
為了解決更有挑戰性的任務,如大數乘法和除法,研究人員還提出了一種方法,根據算術的可學習性對任務進行分類,然后利用基本的算術原理將不可學習的任務(如多位數乘法和除法)分解為一系列可學習的任務。
通過全面的實驗驗證后,文中提出的分解步驟可以有效地提升算術性能。
并且Goat-7 B可以在24 GB VRAM GPU上使用LoRA高效訓練,其他研究人員可以非常容易地重復該實驗,模型、數據集和生成數據集的python腳本即將開源。
會算數的語言模型
語言模型
LLaMA是一組開源的預訓練語言模型,使用公開可用的數據集在數萬億個token上進行訓練后得到,并在多個基準測試上實現了最先進的性能。
先前的研究結果表明,分詞(tokenization)對LLM的算術能力很重要,不過常用的分詞技術無法很好地表示數字,比如位數過多的數字可能會被切分。
LLaMA選擇將數字切分為多個token,確保數字表示的一致性,研究人員認為,實驗結果中表現出的非凡算術能力主要歸功于LLaMA對數字的一致性分詞。
在實驗中,其他微調后的語言模型,如Bloom、OPT、GPT-NeoX和Pythia,無法與LLaMA的算術能力相匹配。
算術任務的可學習性(Learnability of Arithmetic Tasks)
之前有研究人員對使用中間監督解決復合任務(composite task)進行了理論分析,結果表明這種任務是不可學習的,但可以分解為多項式數量的簡單子任務。
也就是說,不可學習的復合問題可以通過使用中間監督或逐步思維鏈(CoT)來學習。
在此分析基礎上,研究人員首先對可學習和不可學習任務進行實驗分類。
在算術計算的背景下,可學習任務通常是指那些可以成功訓練模型以直接生成答案的任務,從而在預定義數量的訓練epochs內實現足夠高的精度。
不可學習的任務是那些即使經過廣泛訓練,模型也難以正確學習和生成直接答案的任務。
雖然任務可學習性變化背后的確切原因尚不完全清楚,但可以假設這與基本模式的復雜性和完成任務所需的工作記憶大小有關。
研究人員通過在簡化的合成環境中專門針對每個任務微調模型來實驗檢查這些任務的可學習性。
可學習的和不可學習的任務
任務分類的結果也與人類的感知相同,通過實踐,人類可以在腦海中計算兩個大數字的加法和減法,無需手算的情況下,可以直接從左(最高有效數字)到右(最低有效數字)寫下最終的數字答案。
不過心算解決大數乘法和除法是一項具有挑戰性的任務。
還可以觀察到,上述對任務的分類結果與GPT-4的性能也一致,特別是GPT-4擅長為大數加法和減法生成直接答案,當涉及到多位乘法和除法任務時,準確性會顯著下降。
像GPT-4這樣強大的模型無法直接解決不可學習的任務,也可能表明,即使經過廣泛的訓練,為這些任務生成直接答案也是極具挑戰性的。
值得注意的是,對于LLaMA來說是可學習的任務可能不一定對于其他LLM來說是可學的。
此外,并非所有被歸類為不可學習的任務對模型來說都是完全不可能學習到的。
例如,兩位數乘兩位數被認為是一項不可學習的任務,但如果訓練集中包含所有可能的2位數乘法枚舉數據的話,模型仍然可以通過過擬合訓練集來直接生成答案。
不過整個過程需要近10個epoch才能達到90%左右的準確率。
而通過在最終答案之前插入文中提出的CoT,該模型可以在1個epoch的訓練后就可以在兩位數乘法中實現相當不錯的精度,也與之前的研究結論一致,即中間監督的存在有助于學習過程。
加法與減法
這兩個算術操作是可學習的,僅通過有監督微調,模型就表現出了準確生成直接數字答案的非凡能力。
盡管模型只是在非常有限的加法數據子集上進行了訓練,但從模型在未見過的測試集上實現了近乎完美的準確率上可以看出來,模型成功地捕獲了算術運算的基本模式,并且無需使用CoT
乘法
研究人員通過實驗驗證了n位數乘1位數的乘法是可學習的,而多位數乘法則無法學習。
為了克服這個問題,研究人員選擇在生成答案之前對LLM進行微調以生成CoT,將多位數乘法分解為5個可學習的子任務:
1. 抽取(extraction),從自然語言指令中抽取算術表達式
2. 拆分(split),將兩者中較小的數拆分為place值
3. 展開(expansion),基于分配性展開求和
4. 乘積(product),同時計算每個乘積
5. 逐項相加(adding term by term),將前兩項相加,復制其余項,得到最終和
其中每個任務都是可學習的。
除法
類似地,可以通過實驗觀察到n位數除以1位數是可以學習的,而多位數除法是不可學習的。
研究人員利用改進慢除法的遞推方程,設計了一個全新的思維鏈提示。
主要思想是從被除數中減去除數的倍數,直到余數小于除數。
數據集
文章中設計的實驗為兩個正整數的加法和減法,每個正整數最多包含16位數字,并且減法運算的結果可能是負數。
為了限制生成的最大序列長度,乘法的結果為12位以內的正整數;兩個正整數的除法中,被除數小于12位,商值6位數以內。
研究人員使用Python腳本合成了一個數據集,生成了大約100萬個問答對,答案包含提出的CoT以及最終的數字輸出,所有數字都是隨機生成的,可以保證重復實例的概率非常低,不過小數字可能會被多次采樣。
微調
為了使該模型能夠基于指令解決算術問題,并促進自然語言問答,研究人員使用ChatGPT生成了數百個指令模板。
在指令調整過程中,從訓練集中為每個算術輸入隨機選擇一個模板,并微調LLaMA-7B,類似于Alpaca中使用的方法。
Goat-7B可以在24GB VRAM GPU上使用LoRA進行微調,在A100 GPU上僅花費大約1.5小時即可完成10萬樣本的微調,并實現近乎完美的精度。
實驗結果
比較Goat和GPT-4在大量乘法和除法方面的性能似乎不公平,因為GPT-4會直接生成答案,而Goat則依賴于設計的思維鏈,所以在GPT-4評估時還在每個提示的結尾加入「Solve it step by step」
不過可以觀察到,雖然GPT-4在某些情況下,長乘法和除法的中間步驟錯了,但最終答案仍然是正確的,也就意味著GPT-4并沒有利用思維鏈的中間監督來提高最終輸出。
最終從GPT-4的解決方案中確定了以下3個常見錯誤:
1. 對應數字的對齊
2. 重復數字
3. n位數乘以1位數的中間結果錯誤
從實驗結果中可以看插到,GPT-4在8D+8D和16D+16D任務上表現相當好,但在大多數16D+8D任務上的計算結果都是錯誤的,盡管直觀上來看,16D+8D應該比16D+16D相對容易。
雖然造成這種情況的確切原因尚不清楚,但一個可能的因素可能是GPT-4不一致的數字分詞過程,使得兩個數字之間很難對齊.