小紅書提出大模型推理加速算法 HASS 刷新 SOTA
在大模型推理領域,投機采樣是一種被廣泛使用的無損加速算法。近期一些投機采樣的工作將大模型的上下文信息(例如 hidden states 和 KV cache)引入草稿模型,可以充分利用大模型的知識來提升加速比,但這類算法也會帶來訓練和解碼的上下文不一致問題。此外,我們也發現現有算法在訓練和解碼的目標上也存在一定的不一致現象。小紅書中臺算法團隊提出的 HASS 算法在目標和上下文上對齊了草稿模型的訓練和解碼階段,達到了普通推理速度的 2.81~4.05 倍,相比 SOTA 方法 EAGLE-2 提升 8%~20%,相關技術已應用在小紅書實際業務場景中。
論文地址
https://arxiv.org/pdf/2408.15766
01 背景
生成式大語言模型(LLMs)在各種任務上表現出令人驚嘆的能力。然而,由于其固有的自回歸解碼機制,人們難以在這些模型上高效推理,這限制了它們在時間敏感場景中的應用。投機采樣技術通過利用額外的資源來增加并發性,提供了一種大模型推理加速的解決方案。
投機采樣(Speculative Sampling)
投機采樣是一種先起草再驗證的解碼范式。在每一步解碼時,先高效地生成多個草稿 token,再使用目標 LLM 并行地驗證這些 token 來加速推理。 表示想要加速推理的目標 LLM,
表示基于前綴
從目標 LLM 生成下一個 token 的條件概率分布(簡寫為
)。
表示一個更高效的草稿模型,
表示基于前綴
從草稿模型生成下一個 token 的條件概率分布(簡寫為
)。投機采樣分為如下3步:
1.使用更高效的草稿模型 來生成
個草稿 token。
2.使用目標 LLM 來并行地驗證這些草稿 token 以及它們從
被生成的概率,接受能使得輸出分布和
一致的所有草稿 token。
3.如果某個草稿 token 被拒絕后,從修正后的分布中采樣一個額外的 token 替代它;如果所有的草稿 token都被接受,額外增加一個新的 token。
具體驗證過程如下:從 中采樣一個草稿 token
,如果
則接受
;否則將以
的概率拒絕并從修正分布
中重新采樣一個 token 接受。
經證明,對于任意的 和
,如此得到的 token 總是與目標 LLM 分布一致。目標 LLM 的每一次前向推理至少產生一個新的 token,而至多產生
個新的 token,生成的個數取決于目標 LLM 和草稿模型的對齊程度。
投機采樣的實際性能取決于兩個因素:草稿模型的解碼成本及其與目標 LLM 的對齊程度。為了獲得與目標 LLM 高度對齊的高效草稿模型,近期的工作提出利用目標 LLM 的上下文信息。例如,EAGLE 使用目標 LLM 的 hidden states 作為草稿模型的輸入特征。然而,這些方法在訓練和解碼階段引入了不一致的上下文,如圖 2 所示。在訓練期間,草稿模型總是能獲取到目標 LLM 在先前時間步的 hidden states。但在解碼期間,草稿模型卻無法獲取到未被驗證時間步的目標 LLM 的 hidden states,這導致了訓練和解碼階段的上下文不一致。這一問題可以看作是投機采樣中在特征層面的 exposure bias。
訓練和解碼階段之間還存在目標上的不一致。在解碼階段,草稿模型的目標是生成目標 LLM 會賦予高概率的 token。在這種情況下,草稿模型應更關注于召回這些高概率 token,而對它們之間的具體順序則可以稍微放松。另外,大部分 LLM 在應用時采取核采樣或 top-k 采樣。在這些解碼策略中,高概率 token 對輸出起著更重要的作用。因此,為了獲得高效的草稿模型,它的訓練目標應考慮到解碼階段的這些特性。據我們所知,現有的涉及訓練草稿模型的投機采樣方法普遍忽視了這些解碼目標。
02 方法
為解決上述的訓練和解碼階段不一致問題,我們提出了協調投機采樣(HASS),旨在通過訓練階段學習協調的表征來解決上述問題。我們的方法包含兩部分:(1)為了讓草稿模型在訓練階段感知到解碼目標,HASS 將推薦系統中的排序蒸餾思想擴展到投機采樣,即協調目標蒸餾;(2)為了解決訓練和解碼間的上下文不一致,我們提出了一種多步的對齊訓練策略,即協調上下文對齊。結合這兩部分,HASS 顯著提高了 LLM 的推理速度。在無需額外推理開銷的情況下,也保持了草稿模型訓練的高效。
協調目標蒸餾(Harmonized Objective Distillation)
HASS 通過引入推薦系統中的排序蒸餾思想,優先考慮草稿模型解碼時更重要的一些 token。具體來說,排序蒸餾的目標是訓練學生模型,使其對教師模型中排名靠前的項賦予更高的排序。在投機采樣中,草稿模型是學生模型,而目標 LLM 是教師模型。具有類似特性的草稿模型在解碼階段將獲得更高的接收率。設 K 個概率最高的 token 組成的集合為 ,其中
代表整個詞匯表。HASS 在訓練時使用以下的 Top-K 蒸餾損失:
其中 和
分別表示目標 LLM 和草稿模型預測下一個詞的條件概率分布。在結合 EAGLE 時,訓練階段可以從目標 LLM 的 hidden states 中獲取
,這意味著結合 Top-K 損失訓練有著和 EAGLE 一樣的訓練效率。
協調上下文對齊(Harmonized Context Alignment)
HASS 采用了多步的對齊訓練策略,使草稿模型在訓練和解碼階段的上下文保持一致。具體來說,HASS 將訓練過程分為 n 步,使草稿模型能夠利用與解碼階段一致的上下文特征。過程如下:
- 第一步與 EAGLE 的訓練相同。在時間步 t+1,草稿模型以目標LLM的特征
作為輸入并生成草稿模型特征
。這一步中,注意力掩碼與因果掩碼一致,不做修改。
- 第二步利用了來自第一步的特征。在時間步 t+1 的自注意力機制中,使用
來生成 query。key 和 value 由
生成,其中
表示拼接操作,
表示早于時間步 t 的特征。注意力掩碼被修改以確保
看到的前一個特征始終是
,如圖 3 中的“HASS Training Step 2“所示。
- 對于第 j 步(j ≥ 3),前一步生成的特征
用于生成時間步 t+1 的query,而 key 和 value 由
生成。
HASS 的訓練開銷是 EAGLE 的 n 倍,但解碼開銷不變。后續實驗證明,HASS 的加速效果在 n 值較小時就會收斂,因此是訓練高效的,具體實現請參考論文的附錄部分。
03 實驗
主要實驗
如表 1、2 所示,HASS 在所有的數據集和目標 LLM 上都表現出了最高的接受長度和最優的加速比。大部分方法在 HumanEval 數據集上加速效果最好,因為代碼生成任務中的固定模版對于草稿模型更易生成從而加速。盡管 PLD 和 Lookahead 無需訓練,但是它們的性能都顯著弱于 EAGLE、EAGLE-2 和 HASS。
協調目標蒸餾的消融實驗
我們改變了 Top-K 損失的 K 和權重,結果如圖 4 所示。使用 Top-K 損失訓練(權重大于 0)時,總是能提升草稿模型的接受長度。當 K 值很小時(K=1)會導致性能下降,可能是因為草稿模型過度關注概率最高的 token 而忽視了其他潛在 token。在 K=5 時,草稿模型的接受長度最大。
我們還嘗試了更多關注高概率 token 的損失函數以替換 Top-K 損失,結果如表 3 所示。BiLD 損失在 T=0 時表現最好,Top-K 損失在 T=1 時表現最好。總體上,Top-K 損失的表現最好。
協調上下文對齊的消融實驗
我們改變了協調上下文對齊的對齊步數,將用 Top-K 損失訓練后的 EAGLE-2 權重作為基準,結果如表 4 所示。在不使用協調上下文對齊時(EAGLE-2+Top-K),草稿模型的效果最差。用 3 或 4 步協調上下文對齊訓練的草稿模型總體上能獲得最優的接受長度。當對齊步數增加到 5 步時,接受長度反而會下降,這可能是因為草稿模型的能力有限,當過度關注后幾步的 token 生成時就會導致在前幾步的預測精度下降。
我們畫出了 HASS 和 EAGLE-2 在每一步生成token時的接受率曲線,如圖 5 所示。可見在后幾步生成 token 時,HASS 的接受率顯著高于 EAGLE-2,驗證了協調上下文對齊的有效性。
但在 LLaMA2-Chat 13B 和 LLaMA3-Instruct 70B 上,HASS 的第一步接受率相比 EAGLE-2 下降了。這可能是因為草稿模型關注后幾步的 token 生成而忽視了第一步的,但第一步的接受率對于接受長度非常關鍵。因此我們考慮調整訓練時每一步對齊的損失權重,來強調前幾步的重要性。具體的,我們對于第 j 步的訓練損失乘上權重 ,結果如表 5 和圖 6 所示。當
從 1.0 降到 0.5 時,草稿模型的接受長度不斷提高。其在第一步的接受率也對應增長,而后幾步的接受率有所下降。當
下降到 0.3 時,訓練過程過分強調了第一步 token 生成,導致了接受長度下降。我們將在多步對齊間取得平衡的探索留到后續工作中。
04 作者簡介
- 樂凡
小紅書中臺算法工程師,目前主要負責大語言模型的相關研究和應用。
- 曉丹
小紅書中臺算法工程師,目前主要負責大語言模型的相關研究和應用。
- 特圖
小紅書中臺算法基礎模型方向負責人,主要研究方向:多模態大模型 x 內容分發技術。
- 瑞格
小紅書中臺算法團隊負責人。