AI慢思考蒸餾進快思考,Llama2躍升至GPT-4水平,不寫過程也能做對題
《思考快與慢》中人類的兩種思考方式,屬實是被Meta給玩明白了。
研究人員通過把AI的“慢思考”結果蒸餾進“快思考”,讓Llama2表現提升了257%,變得比GPT4還能打,同時還能降低推理成本。
這里的快慢兩種思考方式,指的就是2002年諾貝爾經濟學獎得主丹尼爾·卡尼曼推廣的系統1和系統2——
簡單說,系統1是簡單無意識的直覺,速度更快;系統2則是復雜有意識的推理,準確性更強。
Meta所做的“蒸餾”,就是用系統2生成數據,然后對用系統1推理的模型進行微調。
有網友看了后表示,這種模式和人類很像,一旦解決了一個難題,再解決(相似的問題)就變得簡單了。
將系統2蒸餾到系統1
對于大模型而言,模仿人類的“系統2”的方式有很多種,在模型中所處的環節也不盡相同,這里作者一共研究了四種:
- CoT,即Chain of Thought,思維鏈,從提示詞入手讓模型逐步思考;
- S2A,即System 2 Attention,由Meta自己提出,直接修改了模型的注意力機制,屏蔽與任務無關的信息;
- RaR,即Rephase and Respond,先對問題進行重新表述,再根據重述后的問題生成答案;
- BSM,即Branch-Solve-Merge,將復雜任務分解為多個分支,針對每個分支獨立生成評分,再將各個分支的評分綜合。
但從整體流程上看則是殊途同歸,各種“系統2方法”都會在未標注數據集上生成推理結果。
在這過程當中,模型會在給出結果的同時生成詳細的中間推理步驟,但研究人員只保留最終的輸出結果。
然后就得到了輸入-系統2輸出的數據對,可以視為一種無監督的“偽標簽”,將這些數據對收集起來,就形成初步的蒸餾數據集。
當然了,這步得到的數據還不能直接拿來微調系統1模型,需要進行過濾以確保其擁有足夠高的質量。
過濾的具體依據,是一致性和魯棒性。
一致性篩選當中,對每個輸入樣本,都會用系統2模型采樣生成多個輸出,然后通過多數投票等方法進行比較,如果大多數都一致,則認為該輸出是可靠的;
魯棒性篩選是對一個輸入樣本進行適當的擾動,如改變無關細節、調整詞序等,然后觀察系統2模型在擾動前后的輸出是否一致。
篩選后的高質量蒸餾數據,就可以對系統1模型進行無監督微調了。
微調過程可以看作是一種知識蒸餾,但又與與傳統的知識蒸餾不同,這里兩種系統使用的是同一個基礎模型。
系統1模型的目標是直接學到系統2模型的輸出行為,而不是中間的復雜推理過程,在后續推理時也不需要執行系統2的推理步驟,而是直接生成輸出。
但從輸出質量上來看,表現卻能接近系統2模型,也就是實現了系統2能力向系統1的轉移。
那么,為什么要專門收集數據去微調系統1模型,而不直接用系統2模型推理呢,作者也給出了解釋。
道理其實很簡單,從系統2的另一個名字“慢系統”當中,很容易就能看出答案:
因為系統2的速度慢,在實時交互、移動設備部署等場景下,模型的延遲可能是無法接受的。
另外,由于需要輸出完整的推理過程,系統2輸出的token長度也是系統1的數百倍。
就像開頭那位網友說的,系統2把復雜的推理解決了,再將數據喂給系統1,問題對其而言也會變得容易。
從表現上看,這樣的模式也確實讓系統1模型的表現大幅進步,甚至超過了真·系統2模型。
讓Llama2超越GPT-4
針對前面四種不同的系統2方法,研究人員分別使用不同的數據集,在不同的任務上進行了測試。
針對BSM方法,作者采用的數據集是Open Assistant 2和MT-bench,評估了模型作為“評判者”時的表現。
可以看到,在兩個數據集中,Llama-2的表現(人類一致性)分別從32.0%和28.1%,提高到了58.4%和72.4%,最高增幅達到了257%,比CoT方法更加有效。
而且,微調后的模型均超過了系統1版的GPT-4,甚至達到了GPT-4配合CoT的水準。
同時(改變選項位置后的)不一致性也大幅降低,而且和系統2相比,Token數量少到幾乎可以忽略不計。
同時針對MT-Bench不同的子類任務,作者也分別分析了各種方法的人類一致性。
接下來是S2A方法,它主要解決的是模型偏見問題,因此評估時采用了帶偏見的TriviaQA任務。
結果蒸餾后的準確率達到81.3%,超過了原始S2A的76%,生成的token數量也從147個減少到了56個。
RaR的測試目標則是完成一些推理任務,這里作者測試了Last letter concatenation和Coin flip。
在Letter任務中,蒸餾后的系統模型準確率從30%飛升到了98%,也超過了系統1自蒸餾的69.5%,同時也優于原始的RaR方式。
而在Coin flip任務里,蒸餾后的準確率達到 75.69%,也與接近2-步原始RaR的77.2%接近,但生成的token數量大幅減少。
不足的一點是,CoT的蒸餾效果與另外三種大相徑庭,作者發現,在數學推理任務上,CoT的推理能力很難遷移到系統1當中。
在GSM8K數據集上,蒸餾后的模型在k=1時準確率僅為7.13%,k=10時也只有7.35%,甚至不如沒蒸餾之前的版本。
所以,作者認為,接下來的研究目標是進一步明確這種蒸餾的應用場合,找到更類似于人類學習的方式。
論文地址:
???https://arxiv.org/abs/2407.06023??
本文轉自 量子位,作者:量子位
