SFT loss計算的那些坑,完美避開!!!
?SFT 可以說是 LLM 的基本操作了,如果只是想把 SFT 跑起來是非常簡單的,只需要構造 input_ids 和 labels,然后就可以把訓練跑起來。然而,這樣的訓練效率實際上非常低。
所以在訓練時,通常有兩個加速方法:
- 多輪合并
- packing
無論是哪種方法,加速后都需要保證 loss 和原來是等價的。本文主要介紹這兩種加速方法,以及 loss 計算時遇到的問題。
1.多輪合并
假設我們有一個對話,其中 user 和 bot 交互了 3 輪,我們可以構建三個樣本:
input_ids 就是對應的 token id,labels 輸入部分(白色)使用 -100,輸出部分(綠色)使用 input_ids。
這樣計算的 loss 可以表示為:
其中 l_i 表示第 i 個樣本的 loss,n_i 表示第 i 個樣本輸出的 token 數量 (對應綠色部分)。
這樣除了訓練比較慢,沒有什么別的問題。因為不同樣本之間有很多重復計算的前綴,實際上這部分計算一次就行。
2.加速計算
如果將三輪三個樣本合并成一個樣本,可以嘗試這種構造形式。
因為存在 causal attention mask,所以每個 token 只能看到前面的 token,計算上和之前是等價的。
但是這樣有一個坑:如果還是按照剛才的方式構建 input_ids 和 labels (白色用-100,綠色用input_ids)loss 計算是有問題的。
pytorch CrossEntropyLoss 計算 loss 按照下面的方法,默認是"mean"。
所以我們會得到這樣的 loss:
當不同輪次的輸出長度不同時,這種 loss 和剛才的不等價。多輪對話中輸出較短的權重被降低了,輸出較長的被提高了。所以結果就是短輸出的數據訓練不夠充分。
3.Packing
假設我們有兩個對話,第一個是單輪對話,第二個是三輪對話。
正確的 loss:
其中 l_ij 表示第 i 個樣本第 j 輪對話的 loss,n_ij 同理。
問題:真實場景中的訓練集文本長度長短不一,Padding 后矩陣非常稀疏,只有不到一半是有效計算。
加速計算:
將所有樣本拼接成一條,并且加入 attention mask 保證后面的樣本看不見前面的 token。
比如在 flash attention 中,可以調用 flash_attn_varlen_qkvpacked_func,并傳入 cu_seqlens 參數。
和之前一樣,如果不修改 loss 計算方法,packing 的樣本之間會存在因為長度不同,導致訓練不充分的問題。
4.正確方法
一般情況下,loss 計算會經歷三次平均:
- micro batch 維度,分母是這個 micro batch 中的所有 label 不是 -100 的 token 數
- DP 維度,分母是 DP size (和GPU數量相關)
- 梯度累加維度,分母是梯度累加數
我們這里要做的就是禁用這三個平均,統一用這個 global batch 的對話輪數作為分母。
在新版 megatron 框架中,開啟開關 --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。
每個 micro batch 都需要返回這個 micro batch 的輪數,最后框架會自動將所有輪數求和,作為分母。對于分子,需要除以這個輪次的token 數。
正確實現代碼如下(loss_token_num, turn_num 是在構建 data 的時候構建的):
def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
losses = output_tensor.view(-1).float()
loss_mask = loss_mask.view(-1).float()
loss_token_num = loss_token_num.view(-1).float()
# label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
# loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
# losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
# losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
# loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
# losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
# sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
loss = torch.sum(losses * loss_mask / loss_token_num)
loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
# Reduce loss for logging.
loss_and_turn_num = loss_and_turn_num.clone().detach()
torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
# 新版返回結構,開啟 calculate_per_token_loss 開關后,返回三個值
# 第一個是反向傳播實際使用的 loss, 所有 packing 的 loss 求和
# 第二個是 turn_num, 優化器狀態更新時會使用對這個值求和然后縮放梯度
# 第三個是用于日志打印的 loss, 包含兩個值,第一個是所有 loss 求和作為分子,第二個是所有 turn_num 求和作為分母
return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}
5.總結
在 SFT 時,如果要加速,需要注意:
- 不同樣本之間是等價的;
- 不同輪次之間也是等價的。
在合并多輪 / packing 時,需要修改 loss 計算方法,為每個 token 設置正確的權重,并且關閉 DP / 梯度累加的平均。?
