LLM模型貪婪、溫度、Top-k、核采樣方式的區別(附代碼與示例)
在自然語言生成任務中,不同的采樣技術用于從語言模型的輸出中選擇下一個生成的單詞或詞語。這些技術包括貪婪采樣、溫度采樣、Top-k采樣和核(Nucleus)采樣。它們在選擇生成單詞的過程中有不同的策略,本文將介紹這四種采樣方式的區別。
1. 貪婪采樣 (Greedy Sampling)
貪婪采樣是一種直接選擇最可能的下一個詞的策略。
具體步驟為:從模型輸出的logits中,找到概率最大的那個詞,直接選擇它作為輸出。
實現代碼:
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
return logits.argmax(dim=-1)
優點:
- 簡單且計算效率高。
- 保證每一步選擇最有可能的結果。
缺點:
- 可能會導致生成的文本非常重復和缺乏多樣性。
- 貪婪采樣只關注當前概率最大的詞,忽略了其他潛在的好選擇,容易陷入局部最優解。
2. 帶溫度的采樣 (Temperature Sampling)
溫度采樣通過引入一個溫度參數來調整輸出概率的分布,以控制生成文本的多樣性。溫度 T 的作用是平滑或銳化概率分布:
- 當 T = 1 時,采樣為標準隨機采樣。
- 當 T < 1 時,概率分布變得更尖銳,模型更傾向于選擇最可能的詞。
- 當 T > 1 時,概率分布變得更加平滑,模型會更多地探索低概率的詞。
實現代碼
class TemperatureSampler(Sampler):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
dist = Categorical(logits=logits / self.temperature)
return dist.sample()
優點:
- 提供了生成文本的多樣性,尤其是在溫度高時。
- 通過調整溫度參數,可以控制探索(隨機性)與利用(選擇高概率詞)之間的平衡。
缺點:
- 溫度的選擇需要仔細調節,不同任務或場景下對溫度的需求可能不同。
- 溫度過低時,生成的文本趨向于貪婪采樣;溫度過高時,生成的文本可能過于隨機。
3. Top-k采樣
Top-k采樣限制了每次生成時候的候選詞數量,模型只會從概率前k個最高的詞中進行采樣,而忽略其他可能性較小的詞。
實現代碼:
class TopKSampler(Sampler):
def __init__(self, k: int, sampler: Sampler):
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
zeros = logits.new_ones(logits.shape) * float('-inf')
values, indices = torch.topk(logits, self.k, dim=-1)
zeros.scatter_(-1, indices, values)
return self.sampler(zeros)
優點:
- 提供了對生成詞匯的嚴格控制,避免生成概率非常低的詞。
- 通過限制候選詞的數量,避免了一些罕見或不合邏輯的詞被選中。
缺點:
- 需要設定一個合適的 k 值,如果 k 值太小,生成的文本可能會缺乏多樣性;如果 k 值太大,則效果與標準采樣相似。
4. 核采樣 (Nucleus Sampling)
核采樣是一種自適應的采樣方法,它選擇的候選詞集合 V(p) 是滿足累計概率和大于或等于給定閾值 p 的最小詞匯子集。與Top-k采樣不同,核采樣的候選詞數量不是固定的,而是基于累計概率動態確定的。
示例
假設同樣的語境:“今天的天氣很”,但這次我們將會有不同的詞匯及其概率分布,我們也會使用不同的閾值 ( p ) 來展示如何動態確定選詞數量。
(1) 模型預測的詞匯概率
- 好:0.4
- 冷:0.3
- 熱:0.2
- 潮濕:0.05
- 多變:0.03
- 干燥:0.02
(2) 排序與累積概率
按概率從高到低排序并計算累積概率:
- 好:0.4
- 冷:0.7 (0.4 + 0.3)
- 熱:0.9 (0.7 + 0.2)
- 潮濕:0.95 (0.9 + 0.05)
- 多變:0.98 (0.95 + 0.03)
- 干燥:1.00 (0.98 + 0.02)
(3) 確定核集合
這次,我們將選擇不同的閾值 ( p ) 來觀察核集合如何變化:
- **當 ( p = 0.7 )**:核集合包括:“好”和“冷”,因為它們的累積概率首次超過 0.7。
- **當 ( p = 0.9 )**:核集合擴展到:“好”,“冷”,和“熱”,因為它們的累積概率首次超過 0.9。
- **當 ( p = 0.95 )**:核集合進一步擴展到:“好”,“冷”,“熱”和“潮濕”,因為這是累積概率首次超過 0.95。
(4) 抽樣
在每種情況下,我們從對應的核集合中隨機選取一個詞作為下一個詞。選擇的范圍和多樣性取決于 ( p ) 值的大小,而詞的數量是根據這個閾值動態確定的,不是固定的。
實現代碼
class NucleusSampler(Sampler):
"""
## Nucleus 采樣器
Nucleus 采樣器根據給定的概率 p 選擇詞匯的一個子集,并從中進行采樣。
"""
def __init__(self, p: float, sampler: Sampler):
"""
### 初始化
:param p: 要選擇的令牌概率之和,即 p 值。
:param sampler: 用于從選定令牌中進行采樣的采樣器。
"""
# 保存 p 值
self.p = p
# 保存采樣器
self.sampler = sampler
# 初始化 softmax 層,用于將 logits 轉換為概率
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
### 從 logits 中進行 Nucleus 采樣
:param logits: 輸入的 logits 張量,形狀為 (batch_size, num_tokens)。
:return: 采樣得到的令牌索引,形狀為 (batch_size,)。
"""
# 獲取概率 P(x_i | x_1:i-1)
probs = self.softmax(logits)
# 按降序對概率進行排序,并獲取排序后的索引
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# 按排序順序獲取概率的累積總和
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找出累積總和小于 p 的令牌
nucleus = cum_sum_probs < self.p
# 在前面加一個 True,這樣我們可以在累積概率小于 p 的最小令牌數量之后添加一個令牌
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# 獲取對數概率并掩蓋非核部分
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# 使用采樣器從排序后的對數概率中進行采樣
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# 獲取實際的索引
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
# 返回采樣得到的令牌索引
return res.squeeze(-1)
優點:
- 靈活性強,自動調整候選詞集合,避免了固定的詞數限制。
- 在生成文本時能夠更好地平衡多樣性與高概率詞的利用,表現優于Top-k采樣。
缺點:
- 參數 p 的選擇需要調節,不同任務可能需要不同的 p 值。
- 計算復雜度較高,尤其是當處理較大的詞匯表時。
總結
采樣方法 | 優點 | 缺點 |
貪婪采樣 | 簡單、高效,始終選擇最有可能的詞 | 文本生成可能單一,缺乏多樣性 |
溫度采樣 | 通過調整溫度控制多樣性,適應性強 | 溫度的調節需要謹慎,過高或過低的溫度可能產生不理想的結果 |
Top-k采樣 | 控制候選詞數量,避免選擇低概率詞 |
|
核采樣 | 動態選擇候選詞集合,更靈活,生成文本質量較高 | 參數 |
每種采樣方式都有其適用的場景,根據具體的應用和對生成文本的要求,可以選擇不同的采樣策略。