大模型生成提速2倍!單GPU幾小時搞定微調,北大數院校友共同一作丨開源
本文經AI新媒體量子位(公眾號ID:QbitAI)授權轉載,轉載請聯系出處。
只需給大模型“加點小零件”,推理速度立刻提升2倍!
不需要額外訓練一個模型,也不需要對計算硬件做優化,單張A100最快幾小時就能微調完成。
這項新研究名叫Medusa(美杜莎),來自普林斯頓、UIUC、CMU和康涅狄格大學,FlashAttention作者Tri Dao也在其中。
目前,它已經成功部署到伯克利70億參數的“駱馬”Vicuna中,后續還會支持其他大模型,已經登上GitHub熱榜:
但其實,在這種方法推出之前,業界并非沒有大模型推理加速方法,主流的就是DeepMind推出的投機采樣(speculative decoding)。
相比這種方法,Medusa有什么不一樣的地方?
投機采樣的2個“bug”
要想加速大模型推理,需要先知道究竟是什么“限制”了它的速度。
相比計算量的增加,大模型推理速度更容易受到內存帶寬的影響(memory bound)。
這是因為,大模型由于參數量巨大、遠超緩存容量,因此推理時需要先把權重從外部內存(顯存)讀取一次到緩存中,這個過程受內存帶寬限制,速度通常很慢。
因此,模型做批量推理(batch inference)時,一次處理100個tokens和一個tokens時間上區別不大。
基于這個特點,DeepMind去年11月想出了一個名叫投機采樣的神奇操作——
訓練一個更小的模型(draft模型),給大模型提前生成一批“候選詞”,相比于讓大模型自己“思考”生成,直接做“選擇”就好。
由于小模型生成速度比大模型快好幾倍,一旦大模型覺得小模型已有的詞“可用”,就直接拿來,不用自己再緩慢生成一遍。
這個過程,有點像是輸入法的聯想詞候選,在我們(大模型)想好下一個詞用什么之前,輸入法(小模型)先給列出一些備選項:
要是看到覺得不錯,就從中選一個用;要是覺得生成的都不行,就pass掉自己重新打。
這種投機采樣方法確實取得了顯著成效,甚至能輕輕松松在M2 Ultra上以高精度跑340億參數LLaMA大模型。
BUT,這種方法存在兩個問題。
一方面,給大模型找個生成“候選詞”的draft小模型,沒那么容易。
這個小模型可不是隨便抓個生成模型就能用,除了接口統一、概率分布接近等要求,生成質量也不能比大模型差太多。
對于Meta發布的LLaMA這種模型可能還好,既有幾百億參數的大模型版本,又有幾十億參數的小模型版本,可以把參數量更小的版本拿來當draft模型使用。
但對于其他開源大模型,這種方法就不太適用了,自己去搭建訓練一個小模型,不僅時間成本更高,生成效果可能還不達預期。
另一方面,雙模型的組合,使得后續要想做系統調優變得更復雜。
這是因為,相比于大模型自身是一個系統,新增加的draft模型相當于又引入了一個系統。
這樣會導致模型部署起來更復雜,包括額外的網絡傳輸、不同的硬件條件都需要考慮到,在做計算優化時難度也會進一步提升。
為了解決這些問題,Medusa出現了。
不用小模型,加幾個“頭”就行
Medusa(美杜莎,一種長有多個頭的妖怪)是一種新的大模型推理加速方法。
相比投機采樣,它選擇直接給Transformer大模型多加幾個解碼頭(decoding heads),每個頭都是一個單層前饋網絡。
這幾個多出來的解碼頭,可以讓大模型直接一次多生成幾個詞,而不是“擠牙膏式”一個一個生成。
生成準確率也還可以,在預測“下一個詞的下一個詞”時,Medusa準確率達到了60%,還在不斷優化中。
隨后,結合樹狀注意力機制(tree-based attention mechanism)并行驗證這些詞,從而實現推理加速。
基于Medusa,Vicuna的70億、130億和330億參數大模型推理速度,均有了1.9倍以上的效率提升:
針對70億參數的模型,研究者們還在不同任務上測試了一下加速效果,顯示最高在代碼生成上有2.15倍的速度提升。
最關鍵的是,用上Medusa后,并不需要將整個大模型重新訓練一遍。
相比之下,它可以和大模型一起訓練,只需要凍結大模型的參數就行,甚至單個GPU就能搞定。
由于不增加額外的模型,對于分布式推理也很友好。
作者介紹
這項研究有兩位共同一作。
共同一作蔡天樂,普林斯頓大學博士生,研究方向包括優化、表示學習、架構設計等,本科畢業于北京大學數學科學學院,獲得應用數學和計算機科學雙學位。
共同一作Yuhong (Jesse) Li,伊利諾伊大學香檳分校(UIUC)博士生,研究方向是高效機器學習,本科畢業于北京郵電大學。
此外,這項研究也有FlashAttention作者、斯坦福博士Tri Dao的參與。
FlashAttention是一種能加快注意力并減少內存占用的方法,相比PyTorch標準注意力實現,最高能提速9倍。
GitHub地址:https://github.com/FasterDecoding/Medusa
研究地址:https://sites.google.com/view/medusa-llm