LLM實踐系列-細聊LLM的拒絕采樣
最近學強化的過程中,總是遇到“拒絕采樣”這個概念,我嘗試科普一下,爭取用最大白話的方式讓每個感興趣的同學都理解其中思想。
拒絕采樣是 LLM 從統計學借鑒過來的一個概念。其實大家很早就接觸過這個概念,每個刷過 leetcode 的同學大概率都遇到過這樣一個問題:“如何用一枚骰子獲得 1/7 的概率?”
答案很簡單:把骰子扔兩次,獲得 6 * 6 = 36 種可能的結果,丟棄最后一個結果,剩下的 35 個結果平分成 7 份,對應的概率值便為 1/7 。使用這種思想,我們可以利用一枚骰子獲得任意 1/N 的概率。
在這個問題中,我們可以看到拒絕采樣的一些關鍵要素:
- 采樣:從易于采樣的分布(兩個骰子的所有可能結果)中生成樣本;
- 縮放:(扔兩次骰子)獲得更大的樣本分布;
- 拒絕:丟棄(拒絕)不符合條件的樣本(第36種情況);
- 接受:對于剩下的樣本,重新調整概率(通過分組),獲得目標概率分布。
用大白話來總結就是:我們想獲得某個分布(1/7)的樣本,但卻沒有辦法。于是我們對另外一個分布(1/6)進行采樣,但這個分布不能涵蓋原始分布,需要我們縮放這個分布(扔兩次)來包裹起來目標分布。然后,我們以某種規則拒絕明顯不是目標分布的采樣點,剩下的采樣點就可以看作是從目標分布采樣出來的了。
統計學的拒絕采樣
LLM 的拒絕采樣
LLM 的拒絕采樣操作起來非常簡單:讓自己的模型針對 prompt 生成多個候選 response,然后用 reward_model 篩選出來高質量的 response (也可以是 pair 對),拿來再次進行訓練。
解剖這個過程:
- 提議分布是我們自己的模型,目標分布是最好的語言模型;
- prompt + response = 一個采樣結果;
- do_sample 多次 = 縮放提議分布(也可以理解為扔多次骰子);
- 采樣結果得到 reward_model 的認可 = 符合目標分布。
經過這一番操作,我們能獲得很多的訓練樣本,“這些樣本既符合最好的語言模型的說話習慣,又不偏離原始語言模型的表達習慣”,學習它們就能讓我們的模型更接近最好的語言模型。
統計學與 LLM 的映射關系
統計學的拒絕采樣有幾個關鍵要素:
- 原始分布采樣困難,提議分布采樣簡單;
- 提議分布縮放后能涵蓋原始分布;
- 有辦法判斷從提議分布獲取的樣本是否屬于原始分布,這需要我們知道原始分布的密度函數。
LLM 的拒絕采樣也有幾個對應的關鍵要素:
- 我們不知道最好的語言模型怎么說話,但我們知道自己的語言模型如何說話;
- 讓自己的語言模型反復說話,得到的語料大概率會包括最好的語言模型的說話方式;
- reward_model 可以判斷某句話是否屬于最好的語言模型的說話方式。
目前為止,是不是看上去很有道理,很好理解。但其實這里有一個致命的邏輯漏洞:為什么我們的模型反復 do_sample,就一定能覆蓋最好的語言模型呢?這不合邏輯啊,狗嘴里采樣多少次也吐不出象牙啊。
緊接著,就需要我們引出另一個概念了:RLHF 的優化目標是什么?
RLHF 與拒絕采樣
