阿里通義發(fā)布并行計算新策略:1.6B等效4.4B,內(nèi)存消耗驟降95%
既能提升模型能力,又不顯著增加內(nèi)存和時間成本,LLM第三種Scaling Law被提出了。
對于1.6B模型,能實現(xiàn)性能接近4.4B模型,內(nèi)存占用僅為后者的1/22,延遲增加量為1/6。
并且可直接應用于現(xiàn)有模型(如Qwen-2.5),無需從頭訓練。
這就是阿里通義團隊提出的PARSCALE。
目前LLMs的優(yōu)化主要有兩種思路:參數(shù)擴展(如GPT-4)和推理時間擴展(如DeepSeek-R1),但會增加內(nèi)存和時間成本。
阿里通義團隊提出的新范式受CFG(無分類器引導)雙路徑推理機制的啟發(fā)。
他們將CFG的并行思想從 “生成階段的推理優(yōu)化” 擴展為 “訓練和推理全流程的「計算縮放」”。
讓我們來扒一扒技術細節(jié)。
將CFG的并行思想擴展到計算縮放
PARSCALE對于CFG雙路徑的靈感遷移
CFG 通過同時運行有條件生成(輸入提示詞)和無條件生成(不輸入提示詞)兩條路徑,再通過加權(quán)平均融合結(jié)果,提升生成質(zhì)量(如文本相關性、圖像細節(jié)精準度)。
其核心在于利用并行計算(兩次前向傳播)增強模型決策的多樣性和準確性,而無需增加模型參數(shù)。
研究人員觀察到CFG的有效性可能源于計算量的增加(兩次前向傳播),而非單純的條件引導。
由此提出假設:并行計算的規(guī)模(如路徑數(shù)量)可能是提升模型能力的關鍵因素,而非僅依賴參數(shù)規(guī)模或推理時間的串行擴展(如生成更多token)。
CFG用2條并行路徑提升性能,PARSCALE則將路徑數(shù)量擴展為P條(如P=8),并通過可學習的輸入變換和動態(tài)聚合,使并行計算成為一種可擴展的 “計算縮放” 范式。下圖展示了PARSCALE方法。
PARSCALE改進的并行計算框架
1、輸入層:可學習的多路徑輸入變換
核心改進是將CFG的固定雙路徑擴展為P條可學習的并行路徑,每條路徑通過可訓練的前綴嵌入生成差異化輸入。
- 前綴嵌入生成:為每個并行路徑引入可訓練的前綴向量(維度與輸入嵌入一致),拼接在原始輸入前,形成路徑專屬輸入。
- KV緩存區(qū)分:在Transformer的注意力層中,不同路徑的鍵(K)和值(V)緩存相互獨立,確保各路徑的計算互不打擾,增強輸出多樣性。
2、計算層:并行前向傳播
- 并行執(zhí)行:將P個差異化輸入同時輸入模型,利用GPU的并行計算能力,一次性完成P路前向傳播,生成P個輸出流。
- 效率優(yōu)勢:通過批量矩陣運算實現(xiàn)P路并行,計算效率隨P線性增長,共享模型主體參數(shù),僅增加前綴嵌入等少量可訓練參數(shù)。
3、輸出層:動態(tài)加權(quán)聚合
通過多層感知機(MLP)動態(tài)計算各路徑輸出的聚合權(quán)重,替代 CFG 的固定權(quán)重機制:若某路徑輸出與當前輸入語義匹配度高,MLP 會為其分配更高權(quán)重。
PARSCALE更高效
PARSCALE vs. 參數(shù)擴展
當P=8時,1.6B參數(shù)模型在HumanEval的性能(Pass@1=39.1%)接近4.4B參數(shù)模型(Pass@1=45.4%),但內(nèi)存占用僅為后者的1/22,延遲增加量為1/6。
在GSM8K數(shù)學推理任務中,P=8使1.8B模型性能提升34%(相對基準),顯著高于參數(shù)擴展的增益。
兩階段訓練策略
階段1:用傳統(tǒng)方法預訓練模型至收斂(1Ttokens)。
階段2:凍結(jié)主體參數(shù),僅訓練前綴嵌入和聚合權(quán)重(20Btokens,占總數(shù)據(jù)的 2%)。
P=8模型在GSM8K上提升34%,且與從頭訓練效果相當,證明少量數(shù)據(jù)即可激活并行路徑的有效性。且該策略使訓練成本降低約 98%
適配現(xiàn)有模型
研究團隊在Qwen-2.5-3B模型上進行持續(xù)預訓練和參數(shù)高效微調(diào)(PEFT),僅調(diào)整前綴和聚合權(quán)重。
結(jié)果顯示,在代碼生成任務(HumanEval+)中PEFT 方法使Pass@1提升15%,且凍結(jié)主體參數(shù)時仍有效,證明動態(tài)調(diào)整 P 的可行性。
PARSCALE通過可學習的多路徑輸入、動態(tài)聚合權(quán)重、全流程并行優(yōu)化,將CFG的 “雙路徑啟發(fā)” 升級為一種通用的計算縮放范式。
感興趣的朋友可到官方查看更多細節(jié)~
論文鏈接:https://arxiv.org/abs/2505.10475
代碼地址:https://github.com/QwenLM/ParScale
參考鏈接:https://x.com/iScienceLuvr/status/1923262107845525660