一次預測多個token,Meta新模型推理加速3倍,編程任務提高17%
“預測下一個token”被認為是大模型的基本范式,一次預測多個tokens又會怎樣?
Meta AI法國團隊推出“基于多token預測的更快&更好大模型”。
多token預測模型,在編程類任務上表現(xiàn)尤其突出。
與單token預測相比,13B參數(shù)模型在HumanEval上多解決了12%的問題,在MBPP上多解決了17%。
小型算法推理任務上,多token預測也在分布外泛化方面帶來了令人印象深刻的收益。
不過在自然語言任務上,多token預測方法并不能顯著提高7B模型在數(shù)學選擇題上的表現(xiàn)了。
另外一個好處是,即使batch size較大,使用4-token預測訓練的模型,推理速度也可提高3倍。
多token預測更適合編程
具體來說,團隊設計了一種新的多token預測架構,通過n個獨立的輸出頭并行預測n個未來token。
使用大量文本數(shù)據(jù)進行模型訓練,包括代碼和自然語言數(shù)據(jù)集。
再通過實驗比較多token預測和單token預測在多個下游任務上的性能。
為啥多token預測在編程任務和小型算法推理任務上提升更明顯?
團隊猜測可能有兩個原因:
第一,編程語言的邏輯結構更嚴謹,知識的內在聯(lián)系更緊密。一個關鍵節(jié)點可能影響到后續(xù)整個代碼塊的走向。多Token預測能更好捕捉這種長距離依賴。
第二,相比自然語言,編程語言的詞匯量更小。因此即便每次預測多個Token,難度也沒那么大。反而能迫使模型從局部細節(jié)中抽身,著眼全局優(yōu)化。
除了在token層面的實驗,團隊還在更細粒度的字節(jié)級模型上做了嘗試。
他們發(fā)現(xiàn),用8字節(jié)預測替代下一個字節(jié)預測后,模型在MBPP上的Pass@1指標暴增67%,在HumanEval上也提升了20%。
而且推理速度還能再快6倍,簡直不要太香。
對于背后原理,團隊認為多token預測緩解了訓練時Teacher Forcing和推理時自回歸生成之間的分布差異。
也就是說,在訓練的時候,模型看到的都是標準答案,生成的時候卻得靠自己。好比人類在家做練習冊時有答案,考試時卻啥也沒有,就會不適應。
而多token預測相當于訓練時就逼著模型多想幾步,這樣到了考場上,才能應對自如。
從信息論的角度,團隊還給出了一個更精確的論證。
?
傳統(tǒng)的下一個Token預測,目標是最小化當前位置的信息熵。而2-Token預測實際上最小化的是當前和下一位置的信息熵之和。
數(shù)學推導表明,后者其實隱含了更大的互信息權重,也就是更看重當前Token和未來Token的相關性。這就是為什么多Token預測更”有遠見”。
不過在這篇論文中,還有幾個未解決的問題。
?
比如沒有探討如何自動選擇最佳的預測token數(shù)量n,作者提出,未來可以研究使用損失權重調整或動態(tài)調整n來解決最佳n的選擇問題。
此外最佳的詞表大小也可能與單token預測時不同。
總之,看過這篇論文之后,大家都更期待Llama-4了。
論文地址:???https://arxiv.org/abs/2404.19737??
本文轉自 量子位 ,作者:量子位
