無損減少80%激活值內存,提升5倍訓練序列長度,僅需兩行代碼
本文的第一作者羅琪竣、第二作者李夢琦為香港中文大學(深圳)計算機科學博士生,本文在上海交通大學趙磊老師、香港中文大學(深圳)李肖老師的指導下完成。
長序列訓練對于模型的長序列推理等能力至關重要。隨著序列長度增加,訓練所需儲存的激活值快速增加,占據訓練的大部分內存。即便使用梯度檢查點(gradient checkpointing)方法,激活值依然占據大量內存,限制訓練所能使用的序列長度。
來自港中文(深圳)和上海交通大學的團隊提出 StreamBP 算法。通過對鏈式法則進行線性分解和分步計算,StreamBP 將大語言模型訓練所需的激活值內存(logits 和 layer activation)降低至梯度檢查點(gradient checkpointing)的 20% 左右。
- 論文標題:StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs
- 論文:https://arxiv.org/abs/2506.03077
- 代碼:https://github.com/Ledzy/StreamBP
在相同內存限制下,StreamBP 最大序列長度為梯度檢查點的 2.8-5.5 倍。在相同序列長度下,StreamBP 的速度和梯度檢查點接近甚至更快。StreamBP 適用于 SFT、GRPO、PPO 和 DPO 等常見 LLM 目標函數。代碼已開源,可集成至現有訓練代碼。
激活值內存和梯度檢查點
在反向傳播(Backpropagation, BP)的過程中,計算模型梯度需要用到模型的中間輸出(激活值)。舉例來說,對于模型中的線性變換的梯度為
,因而計算
的梯度時需要儲存相應的激活值
。
對于模型中的任意函數變換 的梯度由以下鏈式法則計算:
其中 L 為目標函數,為 Jacobian 矩陣。為了計算以上 Jacobian-vector product,需要在模型 forward 時儲存函數變換
的中間值(激活值),其內存消耗與 batch size、序列長度以及中間值維度正相關。
為了減少激活值的內存消耗,梯度檢查點(gradient checkpointing)方法在 forward 時只儲存每一層網絡的輸入,而不儲存該層的中間值。在 backward 至該層時,將重新 forward 此層輸入來計算得到該層激活值。使用梯度檢查點時儲存的激活值包括:
- 所有層的輸入,一般為激活值內存的 5%-15%。
- 單層的完整激活值,占據超過 85% 的激活值內存。
StreamBP 的核心思想
不同于梯度檢查點,StreamBP 避免儲存單層的完整激活值,而將單層的 BP 過程進行線性分解,序列化計算并累加。注意到對于函數變換,鏈式法則存在以下線性分解:
StreamBP 基于以下觀察:對于 LLM 中的大部分函數變換,如 Transformer 層、lmhead 層,可通過策略性地將輸出分塊
,使得計算塊 Jacobian-vector product
所需的激活值遠小于計算完整的 Jacobian-vector product?;谠撚^察,StreamBP 依次計算上式中 D 個塊的 Jacobian-vector product 并累加,得到準確的梯度。
為了計算塊 Jacobian-vector product,需要分析輸入和輸出的相關性,每次 forward 塊輸入
得到塊輸出
,建立對應子計算圖。以簡單的線性變換
為例,輸出和輸入在行維度上一一對應。StreamBP 按行分塊,每次計算單行的 Jacobian-vector product 并累加。下圖對比了標準 BP 和 StreamBP 在上述線性變換下的實現:
D 步累加得到的和
即為
和
準確梯度。相比于標準 BP,StreamBP 僅需儲存
和
,且總計算 FLOPs 相同。下表為 StreamBP 和標準 BP 的內存和時間對比:
LLM 訓練中的 StreamBP
StreamBP 應用于 LLM 中的 Transformer 層和 lmhead 層,分別用于降低層激活值和 logits 的內存消耗。
與線性變換不同,由于 Transformer 層存在注意力機制,塊輸出并非僅由對應位置的塊輸入
決定,而與該塊及以前所有位置的輸入
都有關。StreamBP 利用
只與塊
有關的性質,建立了如下計算圖:
StreamBP 所需儲存的激活值和注意力掩碼(橙色)大幅低于梯度檢查點(橙色 + 白色部分)。
對于 lmhead 層,當以 SFT 或 GRPO 為目標函數時,觀察到不同位置的 logits 對于目標函數的影響相互獨立。因此,StreamBP 從序列維度分塊,每次計算單塊損失函數的梯度,從而只需儲存單塊 logits 和 logits 梯度。
圖:StreamBP for SFT
圖:StreamBP for GRPO
對于 DPO,由于非線性 sigmoid 函數的存在,每個位置的 logits 對于目標函數的影響并不獨立。StreamBP 利用 logits 梯度在序列維度的獨立性,分塊進行梯度計算。
圖:StreamBP for DPO
實驗結果
我們在單張 A800-80GB GPU 上測試了不同大小的模型,StreamBP 的最大 BP 序列長度為標準 BP 的 23-36 倍,梯度檢查點的 2.5-5.5 倍。
圖:不同序列長度下的 BP 峰值內存
在現有 Transformers 框架下,StreamBP 的實現可避免計算掩碼部分的 pre-attention score(見論文 3.2.2 部分),在長序列訓練下相較于梯度檢查點實現了加速。
通過使用 StreamBP,不同目標函數下最大的序列長度得到了大幅提升。在同樣的序列長度下,StreamBP 允許更大的批處理大小以加速訓練。
表:Qwen 3-4B 單個樣本 BP 時間,序列長度為 9000。
在 Deepspeed ZeRO 分布式訓練模式下,Distributed StreamBP 比梯度檢查點的最大可訓練序列長度提升了5—5.6倍。