成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

SFT loss計算的那些坑,完美避開!!!

發布于 2024-12-11 10:48
瀏覽
0收藏

?SFT 可以說是 LLM 的基本操作了,如果只是想把 SFT 跑起來是非常簡單的,只需要構造 input_ids 和 labels,然后就可以把訓練跑起來。然而,這樣的訓練效率實際上非常低。

所以在訓練時,通常有兩個加速方法:

  • 多輪合并
  • packing

無論是哪種方法,加速后都需要保證 loss 和原來是等價的。本文主要介紹這兩種加速方法,以及 loss 計算時遇到的問題。

1.多輪合并

假設我們有一個對話,其中 user 和 bot 交互了 3 輪,我們可以構建三個樣本:

SFT loss計算的那些坑,完美避開!!!-AI.x社區

input_ids 就是對應的 token id,labels 輸入部分(白色)使用 -100,輸出部分(綠色)使用 input_ids。

這樣計算的 loss 可以表示為:

SFT loss計算的那些坑,完美避開!!!-AI.x社區

其中 l_i 表示第 i 個樣本的 loss,n_i 表示第 i 個樣本輸出的 token 數量 (對應綠色部分)。

這樣除了訓練比較慢,沒有什么別的問題。因為不同樣本之間有很多重復計算的前綴,實際上這部分計算一次就行。

2.加速計算

SFT loss計算的那些坑,完美避開!!!-AI.x社區

如果將三輪三個樣本合并成一個樣本,可以嘗試這種構造形式。

因為存在 causal attention mask,所以每個 token 只能看到前面的 token,計算上和之前是等價的。

但是這樣有一個坑:如果還是按照剛才的方式構建 input_ids 和 labels (白色用-100,綠色用input_ids)loss 計算是有問題的。

pytorch CrossEntropyLoss 計算 loss 按照下面的方法,默認是"mean"。

SFT loss計算的那些坑,完美避開!!!-AI.x社區

所以我們會得到這樣的 loss:

SFT loss計算的那些坑,完美避開!!!-AI.x社區

當不同輪次的輸出長度不同時,這種 loss 和剛才的不等價。多輪對話中輸出較短的權重被降低了,輸出較長的被提高了。所以結果就是短輸出的數據訓練不夠充分。

3.Packing

假設我們有兩個對話,第一個是單輪對話,第二個是三輪對話。

SFT loss計算的那些坑,完美避開!!!-AI.x社區

正確的 loss:

SFT loss計算的那些坑,完美避開!!!-AI.x社區

其中 l_ij 表示第 i 個樣本第 j 輪對話的 loss,n_ij 同理。

問題:真實場景中的訓練集文本長度長短不一,Padding 后矩陣非常稀疏,只有不到一半是有效計算。

加速計算:

SFT loss計算的那些坑,完美避開!!!-AI.x社區

將所有樣本拼接成一條,并且加入 attention mask 保證后面的樣本看不見前面的 token。

比如在 flash attention 中,可以調用 flash_attn_varlen_qkvpacked_func,并傳入 cu_seqlens 參數。

和之前一樣,如果不修改 loss 計算方法,packing 的樣本之間會存在因為長度不同,導致訓練不充分的問題。

4.正確方法

一般情況下,loss 計算會經歷三次平均:

  • micro batch 維度,分母是這個 micro batch 中的所有 label 不是 -100 的 token 數
  • DP 維度,分母是 DP size (和GPU數量相關)
  • 梯度累加維度,分母是梯度累加數

我們這里要做的就是禁用這三個平均,統一用這個 global batch 的對話輪數作為分母。

在新版 megatron 框架中,開啟開關 --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。

每個 micro batch 都需要返回這個 micro batch 的輪數,最后框架會自動將所有輪數求和,作為分母。對于分子,需要除以這個輪次的token 數。

正確實現代碼如下(loss_token_num, turn_num 是在構建 data 的時候構建的):

def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss_token_num = loss_token_num.view(-1).float()
    # label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
    # loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
    # losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
    # losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
    # loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
    # losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
    # sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
    loss = torch.sum(losses * loss_mask / loss_token_num)
    loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
    # Reduce loss for logging.
    loss_and_turn_num = loss_and_turn_num.clone().detach()
    torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
    # 新版返回結構,開啟 calculate_per_token_loss 開關后,返回三個值
    # 第一個是反向傳播實際使用的 loss, 所有 packing 的 loss 求和
    # 第二個是 turn_num, 優化器狀態更新時會使用對這個值求和然后縮放梯度
    # 第三個是用于日志打印的 loss, 包含兩個值,第一個是所有 loss 求和作為分子,第二個是所有 turn_num 求和作為分母
    return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}

5.總結

在 SFT 時,如果要加速,需要注意:

  • 不同樣本之間是等價的;
  • 不同輪次之間也是等價的。

在合并多輪 / packing 時,需要修改 loss 計算方法,為每個 token 設置正確的權重,并且關閉 DP / 梯度累加的平均。?

本文轉載自??丁師兄大模型??,作者:Ethan Yan ????

收藏
回復
舉報
回復
相關推薦
主站蜘蛛池模板: 天天爱天天操 | 国产成人午夜电影网 | va精品 | 亚洲国产成人av好男人在线观看 | 在线视频成人 | 午夜视频免费在线 | 一区二区三区亚洲 | 日韩欧美高清 | 少妇一级淫片免费播放 | 97精品超碰一区二区三区 | 日韩和的一区二区 | 99久久精品免费视频 | 亚洲精品一区国语对白 | 色综合视频 | 久久伊人精品一区二区三区 | 91精品国产91久久久久久最新 | 国产精品a久久久久 | 日韩一区二区在线免费观看 | 亚洲欧美精品一区 | 欧美寡妇偷汉性猛交 | 九九99久久 | 一级片av| 久久久男人的天堂 | 男女爱爱福利视频 | 看a网站 | 91中文在线观看 | 日本高清精品 | 久久精品国产99国产精品 | 日韩一区二区三区在线视频 | 日本在线视频一区二区 | 9999久久 | 精品一区二区三区中文字幕 | 欧美成人免费在线 | 国产xxxx岁13xxxxhd | 日韩午夜一区二区三区 | 久久久久久九九九九九九 | 91伊人网| 人人擦人人干 | 国产区一区 | 第一区在线观看免费国语入口 | 一级黄色片一级黄色片 |