一文詳解MHA、GQA、MQA原理 原創
前言
本文回顧一下MHA、GQA、MQA,詳細解讀下MHA、GQA、MQA這三種常見注意力機制的原理。
圖1 MHA、GQA、MQA一覽
self-attention
self-attention
在自注意力機制中,輸入通常是一個統一的輸入矩陣,而這個矩陣后續會通過乘以不同的權重矩陣來轉換成三個不同的向量集合:查詢向量Q、鍵向量K和值向量V。這三組向量是通過線性變換方式生成:
1.查詢向量 (Q): Q=XWQ
2.鍵向量 (K): K=XWK
3.值向量 (V): V=XWV
WQ ,WK和WV是可學習的權重矩陣,分別對應于查詢、鍵和值。這些矩陣的維度取決于模型的設計,通常它們的輸出維度(列數) 是預先定義的,以滿足特定的模型架構要求。 在Transformer模型中,使用不同的權重矩陣WQ ,WK和WV來分別生成查詢向量Q、鍵向量K和值向量V的目的是為了允許模型在不同的表示空間中學習和抽取特征。這樣做增加了模型的靈活性和表達能力,允許模型分別優化用于匹配(Q 和K)和用于輸出信息合成(V)的表示。
在自注意力和多頭注意力機制中,使用
作為縮放因子進行縮放操作是為了防止在計算點積時由于維度較高導致的數值穩定性問題。這里的dk是鍵向量的維度。如果不進行縮放,當dk較大時,點積的結果可能會變得非常大,這會導致在應用softmax函數時產生的梯度非常小。因為softmax函數是通過指數函數計算的,大的輸入值會使得部分輸出接近于1,而其他接近于0,從而導致梯度消失,這會在反向傳播過程中造成梯度非常小,使得學習變得非常緩慢。
通過點積結果除以
,可以調整這些值的范圍,使得它們不會太大。這樣,softmax的輸入在一個合適的范圍內,有助于避免極端的指數運算結果,從而保持數值穩定性和更有效的梯度流。這個操作確保了即使在dk很大的情況下, 注意力機制也能穩定并有效地學習。
代碼實現
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, seq_length):
super(SelfAttention, self).__init__()
self.input_size = seq_length
# 定義三個權重矩陣:Wq、Wk、Wv
self.Wq = nn.Linear(seq_length, seq_length) # 線性變換
self.Wk = nn.Linear(seq_length, seq_length)
self.Wv = nn.Linear(seq_length, seq_length)
def forward(self, input):
# 計算Q,K,V 三個矩陣
q = self.Wq(input)
k = self.Wk(input)
v = self.Wv(input)
# 計算QK^T,即向量之間的相關度
attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))
# 計算向量權重,softmax歸一化
attention_weight = F.softmax(attention_scores, dim=-1)
# 計算輸出
output = torch.matmul(attention_weight, v)
return output
x = torch.randn(2, 3, 4)
Self_Attention = SelfAttention(4) # 傳入輸入向量的維度
output = Self_Attention(x)
print(output.shape)
MHA(多頭注意力)
Transformer 編碼器塊內的縮放點積注意力機制和多頭注意力機制
MHA計算過程
代碼實現
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
self.wk = nn.Linear(embed_dim, embed_dim)
self.wv = nn.Linear(embed_dim, embed_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def mh_split(self, hidden):
batch_size = hidden.shape[0]
x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)
# 注意力計算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 拼接多頭
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線性變換
output = self.wo(output)
return output
x = torch.rand(2, 3, 36)
print(x)
output = MultiHeadAttention(36, 6)
y = output(x)
print(y.shape)
MHA 能夠理解輸入不同部分之間的關系。然而,這種復雜性是有代價的——對內存帶寬的需求很大,尤其是在解碼器推理期間。主要問題的關鍵在于內存開銷。在自回歸模型中,每個解碼步驟都需要加載解碼器權重以及所有注意鍵和值。這個過程不僅計算量大,而且內存帶寬也大。隨著模型規模的擴大,這種開銷也會增加,使得擴展變得越來越艱巨。
因此,多查詢注意 (MQA) 應運而生,成為緩解這一瓶頸的解決方案。其理念簡單而有效:使用多個查詢頭,但只使用一個鍵和值頭。這種方法顯著減少了內存負載,提高了推理速度。
MQA(多查詢注意力)
圖2 MHA和MQA的差別
MQA是MHA的一種變體,也是用于自回歸解碼的一種注意力機制。,圖1、圖2很形象的描繪了MHA和MQA的對比,與MHA 不同的是,MQA 讓所有的Head之間共享同樣的一份 K 和 V 矩陣(意味K和V的計算唯一),只讓 Q 保留了原始多頭的性質(每個Head存在不同的轉換),從而大大減少 K 和 V 矩陣的參數量以及KV Cache的顯存占用,以此來達到提升推理速度,但是會帶來精度上的損失。MQA被大量應用于LLM中,如ChatGLM2。
左 - 多頭注意力,中 - 多查詢注意力,右 - 將現有的 MHA 檢查點轉換為 MQA
如何將現有的預訓練多頭注意力模型轉換為多查詢注意力模型 (MQA)?從現有的多頭模型創建多查詢注意力模型涉及兩個步驟:模型結構的轉換和隨后的預訓練。
- 模型結構的轉換:此步驟將多頭模型的結構轉換為多查詢模型。它是通過將原始模型的多個頭的鍵和值的投影矩陣(線性層)合并(均值池化)為鍵和值的單個投影矩陣來實現的。這種均值池化方法被發現比選擇現有鍵和值頭之一或從頭開始初始化新的鍵和值頭更有效。生成的結構具有合并的鍵和值投影,這是多查詢模型的特征。
- 對轉換后的模型進行預訓練:結構轉換后,模型將接受額外的訓練。此訓練不像原始模型訓練那樣廣泛;它只是原始模型訓練步驟的一小部分(表示為 α)。此預訓練階段的目的是讓模型根據其新的簡化注意力機制調整和優化其性能。訓練遵循與原始相同的方法,確保學習動態的一致性。
代碼實現
import torch
import torch.nn as nn
class MultiQuerySelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiQuerySelfAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# MHA
# self.wk = nn.Linear(embed_dim, embed_dim)
# self.wv = nn.Linear(embed_dim, embed_dim)
# MQA
self.wk = nn.Linear(embed_dim, self.head_dim)
self.wv = nn.Linear(embed_dim, self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def q_h_split(self, hidden, head_num=None):
batch_size, seq_len = hidden.size()[:2]
# q拆分多頭
if head_num == None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 這是MQA: 需要拆分k和v,這里面的head_num =1 的
# 最終返回維度(batch_size, 1, seq_len, head_dim)
return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
# 這是MHA的
# q, k ,v = self.split(q), self.split(k), self.split(v)
# 這是MQA的
q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)
# 注意力計算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:", scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 多頭合并
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線性變換
output = self.wo(output)
return output
x = torch.rand(3, 12, 512)
atten = MultiQuerySelfAttention(512, 8)
y = atten(x)
print(y.shape)
GQA(分組查詢注意力)
雖然MQA方式大幅減小了參數數量,但是,帶來推理加速的同時會造成模型性能損失,且在訓練過程使得模型變得不穩定(復雜度的降低可能會導致質量下降和訓練不穩定),因此在此基礎上提出了GQA,它將Query進行分組,每個組內共享一組Key、Value。(GQA在LLaMA-2 和 Mistral7B得到應用)
GQA 的數學原理:
分組:在 GQA 中,傳統多頭模型中的查詢頭 (Q) 被分成 G 組。每組分配一個鍵 (K) 和值 (V) 頭。此配置表示為 GQA-G,其中 G 表示組數。
GQA 的特殊情況:
- GQA-1 = MQA:只有一個組(G = 1),GQA 等同于 MQA,因為所有查詢頭只有一個鍵和值頭。
- GQA-H = MHA:當組數等于頭數(G = H)時,GQA 退化為 MHA,每個查詢頭都有其唯一的鍵和值頭。
對每個組中原始頭部的鍵和值投影矩陣進行均值池化,以將MHA模型轉換為 GQA 模型。此技術對組中每個頭部的投影矩陣進行平均,從而為該組生成單個鍵和值投影。
通過利用 GQA,該模型在 MHA 質量和 MQA 速度之間保持平衡。由于鍵值對較少,內存帶寬和數據加載需求被最小化。G 的選擇代表了一種權衡:更多的組(更接近 MHA)可帶來更高的質量但性能較慢,而更少的組(接近 MQA)可提高速度但有犧牲質量的風險。此外,隨著模型規模的擴大,GQA 允許內存帶寬和模型容量按比例減少,與模型規模相對應。相比之下,對于更大的模型,在 MQA 中減少到單個鍵和值頭可能會過于嚴重。
代碼實現
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(GroupedQueryAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# 這是MHA的
# self.wk = nn.Linear(embed_dim, embed_dim)
# self.wv = nn.Linear(embed_dim, embed_dim)
# 這是MQA的
# self.wk = nn.Linear(embed_dim, self.head_dim)
# self.wv = nn.Linear(embed_dim, self.head_dim)
# 這是GQA的
self.group_num = 4 # 這是4個組
self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)
self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self, hidden, group_num=None):
batch_size, seq_len = hidden.size()[:2]
# q需要拆分多頭
if group_num == None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 這是kv需要拆分的多頭
x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,
self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
# 這是MHA的
# q, k ,v = self.split(q), self.split(k), self.split(v)
# 這是MQA的
# q, k ,v = self.split(q), self.split(k, 1), self.split(v, 1)
# 這是GQA的
q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)
# 注意力計算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:", scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 合并多頭
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線性變換
output = self.wo(output)
return output
x = torch.ones(3, 12, 512)
atten = GroupedQueryAttention(512, 8)
y = atten(x)
print(y.shape)
參考文獻
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,https://arxiv.org/pdf/2305.13245
- Attention Is All You Need,https://arxiv.org/pdf/1706.03762
- Fast Transformer Decoding: One Write-Head is All You Need,https://arxiv.org/pdf/1911.02150v1
本文轉載自公眾號大模型自然語言處理 作者:余俊暉
原文鏈接:??https://mp.weixin.qq.com/s/72fGm-qYV5DdCGz-bNjuXQ??
