LLM高效推理:KV緩存與分頁注意力機制深度解析
隨著大型語言模型(LLM)規(guī)模和復(fù)雜性的持續(xù)增長,高效推理的重要性日益凸顯。KV(鍵值)緩存與分頁注意力是兩種優(yōu)化LLM推理的關(guān)鍵技術(shù)。本文將深入剖析這些概念,闡述其重要性,并探討它們在僅解碼器(decoder-only)模型中的工作原理。
常規(guī)推理機制
首先,我們通過一個簡單的例子來理解Transformer模型中典型的推理過程。假設(shè)我們需要生成短語:
“The quick brown fox jumped”
以下是常規(guī)推理的簡化實現(xiàn):
import numpy as np
# 簡化的嵌入表示,僅用于演示
embeddings = {
'The': np.array([1, 0, 0, 0]),
'quick': np.array([0, 1, 0, 0]),
'brown': np.array([0, 0, 1, 0]),
'fox': np.array([0, 0, 0, 1]),
'jumped': np.array([1, 1, 0, 0])
}
# 權(quán)重矩陣(簡化)
W_Q = W_K = W_V = np.array([[1, 0],
[0, 1],
[0, 0],
[0, 0]])
def compute_attention(self, input_words):
# 將單詞轉(zhuǎn)換為嵌入向量
E = np.array([embeddings[word] for word in input_words])
# 計算所有token的K和V矩陣
K = E @ W_K # 形狀: (seq_len, 2)
V = E @ W_V # 形狀: (seq_len, 2)
# 計算最后一個token的Q矩陣
Q = E[-1] @ W_Q # 形狀: (1, 2)
# 計算縮放的點積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 將注意力權(quán)重應(yīng)用于V矩陣
output = attention_weights @ V # 形狀: (1, 2)
return output
以下是逐步生成的過程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention(input_words_step3)
冗余計算:觀察上述代碼可以發(fā)現(xiàn)對于每個新生成的token:
- 需要為所有先前的token重新計算K和V矩陣。
- 矩陣的大小隨著token數(shù)量的增加而增大。
- 存在大量不必要的重復(fù)計算。
KV緩存機制
當(dāng)使用Transformer模型生成文本時,通過緩存鍵(K)和值(V)矩陣,可以顯著優(yōu)化推理過程。下圖展示了KV緩存的工作原理:
在上圖中:
- q_new表示最新token的查詢向量。
- K_prev和V_prev是從先前計算中緩存得到的鍵和值矩陣。
- k_new和v_new僅為當(dāng)前新token計算。
- 藍(lán)色箭頭表示如何利用緩存值和新值計算注意力。
以下是KV緩存的實現(xiàn)示例:
def compute_attention_with_cache(self, input_words):
"""使用KV緩存計算注意力"""
# 獲取新token(序列中的最后一個單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 更新緩存的K和V矩陣
if self.cached_K is None:
self.cached_K = K_new.reshape(1, -1) # 形狀: (1, 2)
self.cached_V = V_new.reshape(1, -1) # 形狀: (1, 2)
else:
self.cached_K = np.vstack([self.cached_K, K_new]) # 形狀: (seq_len, 2)
self.cached_V = np.vstack([self.cached_V, V_new]) # 形狀: (seq_len, 2)
# 計算最后一個token的Q矩陣
Q = e_new @ W_Q # 形狀: (2,)
# 使用緩存的K矩陣計算縮放的點積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ self.cached_K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 使用緩存的V矩陣計算注意力輸出
output = attention_weights @ self.cached_V # 形狀: (1, 2)
return output
以下是逐步生成的過程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_cache(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_cache(input_words_step2)
# 步驟 3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention_with_cache(input_words_step3)
比較有無KV緩存的推理計算
內(nèi)存需求與挑戰(zhàn)
我們來看一個使用典型模型參數(shù)的實際例子:
- 序列長度: 4096
- 層數(shù): 32
- 注意力頭數(shù): 32
- 頭維度: 128
- 精度: FP16 (2 bytes)
每個token所需的內(nèi)存:
KV_cache_per_token = 2×num_layers×(num_heads×head_dim)×precision
= 2 × 32 × (32 × 128) × 2 bytes
= 2 × 32 × 4096 × 2 bytes
= 524,288 bytes
≈ 0.5 MB
KV緩存的低效性
盡管KV緩存顯著提高了計算效率,但它也帶來了內(nèi)存管理方面的挑戰(zhàn)。以下是三種主要的內(nèi)存低效類型:
內(nèi)部碎片
- 由因未知輸出長度而導(dǎo)致的過度分配引起。
- 示例:在圖像中,2040個槽位從未被使用。
- 影響:可能浪費高達(dá)60-80%的已分配內(nèi)存。
- 解決方案:更精確的輸出長度估計或動態(tài)分配策略。
預(yù)留浪費
- 為將來的token生成而預(yù)留的內(nèi)存。
- 在圖像中顯示為“3 slots future used (reserved)”。
- 維持生成連續(xù)性的必要措施。
- 可以通過更好地預(yù)測所需的未來槽位來優(yōu)化。
外部碎片
- 由處理具有不同序列長度的多個請求導(dǎo)致。
- 在不同請求之間創(chuàng)建內(nèi)存間隙。
- 解決方案包括內(nèi)存碎片整理和智能請求批處理。
如上圖所示,通常僅有20-40%的KV緩存被用于存儲實際的token狀態(tài)。
分頁注意力:解決內(nèi)存低效的方案
為了應(yīng)對這些內(nèi)存挑戰(zhàn),可以采用分頁注意力機制。
分頁注意力是一種用于有效處理Transformer模型中長序列的技術(shù),它通過將注意力計算分解為更小、更易于管理的“頁”或“塊”來實現(xiàn)。這種方法降低了內(nèi)存消耗和計算復(fù)雜度,從而能夠處理原本因過大而無法放入內(nèi)存的序列。
def compute_attention_with_paging(self, input_words):
"""使用分頁KV緩存計算注意力"""
# 獲取新token(序列中的最后一個單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 確定當(dāng)前頁的索引
total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
current_page_idx = (total_tokens - 1) // PAGE_SIZE
# 如果需要,初始化新頁
if len(self.cached_K_pages) <= current_page_idx:
self.cached_K_pages.append([])
self.cached_V_pages.append([])
# 將K和V添加到當(dāng)前頁的緩存中
self.cached_K_pages[current_page_idx].append(K_new)
self.cached_V_pages[current_page_idx].append(V_new)
# 計算當(dāng)前token的Q矩陣
Q = e_new @ W_Q # Shape: (2,)
# 僅在當(dāng)前頁內(nèi)計算注意力
K_current_page = np.array(self.cached_K_pages[current_page_idx])
V_current_page = np.array(self.cached_V_pages[current_page_idx])
# 添加縮放因子,用于點積注意力
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K_current_page.T) / scale
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, current_page_size)
# 將注意力權(quán)重應(yīng)用于當(dāng)前頁中的V矩陣
output = attention_weights @ V_current_page
return output
以下是逐步生成的過程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_paging(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_paging(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention_with_paging(input_words_step3)
為何需要分頁注意力?
- 內(nèi)存約束:由于注意力矩陣的規(guī)模與序列長度呈平方關(guān)系,Transformer模型在處理長序列時面臨嚴(yán)重的內(nèi)存限制。
- 長序列處理:在諸如語言建模或文檔摘要等任務(wù)中,序列可能非常長。
- 效率:通過以分頁的方式處理注意力計算,可以將內(nèi)存使用量保持在一個常量水平,從而不受序列長度的影響。
分頁注意力如何工作?
- 分割序列:將輸入序列分割成更小的塊或頁。
- 局部注意力:在每個頁內(nèi)計算注意力。
- 跨頁注意力:可選地,允許有限的跨頁注意力,以捕獲頁之間的依賴關(guān)系。
- 滑動窗口:使用重疊的頁來確保連續(xù)性。
上述實現(xiàn)僅限于局部注意力,跨頁注意力和滑動窗口的實現(xiàn)超出了本文的范圍,將在后續(xù)文章中詳細(xì)介紹。
分頁注意力的討論
優(yōu)勢
- 內(nèi)存效率:注意力計算被限制在頁大小內(nèi),內(nèi)存使用量保持恒定,不受總序列長度的影響。
- 計算效率:降低了注意力計算的復(fù)雜度。
- 可擴展性:能夠處理原本無法放入內(nèi)存的超長序列。
權(quán)衡與考慮
- 上下文信息受限:模型會丟失跨頁的一些依賴關(guān)系,這對于需要全局上下文的任務(wù)可能很重要。
可能的解決方案:
- 重疊頁:允許頁之間重疊一定數(shù)量的token,重疊區(qū)域的token可以關(guān)注前一頁的token。
- 分層注意力:使用更高層次的注意力機制來連接跨頁的信息。
重疊頁、分層注意力、跨頁注意力和滑動窗口的完整實現(xiàn)超出了本文的范圍。
以下實現(xiàn)僅捕獲局部注意力,作為示例不應(yīng)在實際應(yīng)用中使用:
# 本實現(xiàn)僅為演示和理解目的而設(shè)計的簡化版本。
# 實際應(yīng)用中需要更高效和可擴展的實現(xiàn)。
import numpy as np
embeddings = {
'The': np.array([1, 0, 0, 0]),
'quick': np.array([0, 1, 0, 0]),
'brown': np.array([0, 0, 1, 0]),
'fox': np.array([0, 0, 0, 1]),
'jumped': np.array([1, 1, 0, 0])
}
W_Q = W_K = W_V = np.array([[1, 0],
[0, 1],
[0, 0],
[0, 0]])
PAGE_SIZE = 2 # 演示用的小頁尺寸
class AttentionWithCache:
def __init__(self):
self.cached_K = None # 形狀: (seq_len, 2)
self.cached_V = None # 形狀: (seq_len, 2)
self.cached_K_pages = [] # 包含K向量的頁列表
self.cached_V_pages = [] # 包含V向量的頁列表
def softmax(self, x, axis=-1):
"""
為x中的每組分?jǐn)?shù)計算Softmax值。
包含數(shù)值穩(wěn)定性改進(jìn)。
"""
# 應(yīng)用最大值減法以提高數(shù)值穩(wěn)定性
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def compute_attention(self, input_words):
# 將單詞轉(zhuǎn)換為嵌入向量
E = np.array([embeddings[word] for word in input_words])
# 計算所有token的K和V矩陣
K = E @ W_K # 形狀: (seq_len, 2)
V = E @ W_V # 形狀: (seq_len, 2)
# 計算最后一個token的Q矩陣
Q = E[-1] @ W_Q # 形狀: (1, 2)
# 計算縮放的點積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 將注意力權(quán)重應(yīng)用于V矩陣
output = attention_weights @ V # 形狀: (1, 2)
return output
def compute_attention_with_cache(self, input_words):
"""使用KV緩存計算注意力"""
# 獲取新token(序列中的最后一個單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 更新緩存的K和V矩陣
if self.cached_K is None:
self.cached_K = K_new.reshape(1, -1) # 形狀: (1, 2)
self.cached_V = V_new.reshape(1, -1) # 形狀: (1, 2)
else:
self.cached_K = np.vstack([self.cached_K, K_new]) # 形狀: (seq_len, 2)
self.cached_V = np.vstack([self.cached_V, V_new]) # 形狀: (seq_len, 2)
# 計算最后一個token的Q矩陣
Q = e_new @ W_Q # 形狀: (2,)
# 使用緩存的K矩陣計算縮放的點積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ self.cached_K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 使用緩存的V矩陣計算注意力輸出
output = attention_weights @ self.cached_V # 形狀: (1, 2)
return output
def compute_attention_with_paging(self, input_words):
"""使用分頁KV緩存計算注意力"""
# 獲取新token(序列中的最后一個單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 確定當(dāng)前頁的索引
total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
current_page_idx = (total_tokens - 1) // PAGE_SIZE
# 如果需要,初始化新頁
if len(self.cached_K_pages) <= current_page_idx:
self.cached_K_pages.append([])
self.cached_V_pages.append([])
# 將K和V添加到當(dāng)前頁的緩存中
self.cached_K_pages[current_page_idx].append(K_new)
self.cached_V_pages[current_page_idx].append(V_new)
# 計算當(dāng)前token的Q矩陣
Q = e_new @ W_Q # Shape: (2,)
# 僅在當(dāng)前頁內(nèi)計算注意力
K_current_page = np.array(self.cached_K_pages[current_page_idx])
V_current_page = np.array(self.cached_V_pages[current_page_idx])
# 添加縮放因子,用于點積注意力
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K_current_page.T) / scale
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, current_page_size)
# 將注意力權(quán)重應(yīng)用于當(dāng)前頁中的V矩陣
output = attention_weights @ V_current_page
return output
def compare_implementations():
print("原始實現(xiàn):")
attention1 = AttentionWithCache()
# 使用原始方法處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention1.compute_attention(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print("\nKV緩存實現(xiàn):")
attention2 = AttentionWithCache()
# 使用KV緩存處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention2.compute_attention_with_cache(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print("\n分頁注意力實現(xiàn):")
attention3 = AttentionWithCache()
# 使用分頁注意力處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention3.compute_attention_with_paging(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print(f"頁數(shù): {len(attention3.cached_K_pages)}")
print(f"當(dāng)前頁大小: {len(attention3.cached_K_pages[-1])}\n")
if __name__ == "__main__":
compare_implementations()
總結(jié)
KV緩存和分頁注意力是提升LLM推理效率和可擴展性的重要技術(shù)。KV緩存通過消除冗余計算來優(yōu)化計算過程,而分頁注意力則解決了處理長序列時面臨的內(nèi)存限制。
隨著模型規(guī)模和復(fù)雜性的不斷增長,這些優(yōu)化技術(shù)對于實際應(yīng)用變得至關(guān)重要。深入理解和有效實施這些技術(shù),可以顯著提升LLM部署的性能和效率。