Attention Mechanisms in Large Models: MHA, GQA, MQA, and MLA in Theory and Practice
The attention mechanism is the soul of the Transformer architecture and the key to balancing performance and efficiency in large models. From the original multi-head attention (MHA) to the more recent multi-head latent attention (MLA), researchers have kept refining how queries (Query), keys (Key), and values (Value) interact, searching for the sweet spot between model expressiveness and computational cost. This article systematically reviews the theoretical foundations of four mainstream attention mechanisms, MHA, MQA, GQA, and MLA, and analyzes their design motivations, core principles, and code implementations.
1. Multi-Head Attention (MHA): The Foundation of Parallel Feature Capture
1.1 Motivation: Breaking the Expressiveness Bottleneck of Single-Head Attention
Before the Transformer, conventional attention mechanisms (such as Bahdanau attention) modeled sequence dependencies with a single set of Query, Key, and Value vectors, which made it hard to capture feature patterns along different dimensions (e.g., syntactic structure and semantic association) at the same time. MHA's core innovation is to project the input into multiple subspaces and compute attention in parallel, so the model can simultaneously attend to different positions of the sequence from multiple feature perspectives, strengthening its ability to model complex patterns.
1.2 Core Mechanism: Project, Split, Attend, and Merge
The MHA computation can be summarized in four steps: linear projection, head splitting, attention computation, and output projection:
- Linear projection: the input sequence is mapped by three learnable matrices to Query (Q), Key (K), and Value (V), each of shape (batch_size, seq_len, hidden_size).
- Head splitting: Q, K, and V are split along the feature dimension into num_heads heads, each of dimension head_dim = hidden_size / num_heads, and reshaped to (batch_size, num_heads, seq_len, head_dim).
- Scaled dot-product attention: each head computes its attention weights independently as $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$, where the scaling factor $\sqrt{d_k}$ keeps the dot products from growing too large and causing vanishing gradients after the softmax.
- Merge and project: the outputs of all heads are concatenated and mapped back to the original dimension (hidden_size) by a linear layer.
1.3 Strengths and Limitations
- Strengths: the parallel multi-head design lets the model capture multi-scale features (such as local syntax and global semantics) and is a core source of the strong expressiveness of large models.
- Limitations: parameters and computation grow linearly with the number of heads (the Q, K, V projection layers alone contribute about $3 \cdot \text{hidden\_size}^2$ parameters), and at inference time the K and V of every head must be cached, so the KV cache becomes very large (roughly $2 \cdot \text{num\_heads} \cdot \text{head\_dim} = 2 \cdot \text{hidden\_size}$ cached values per token per layer), which limits long sequences and large-scale deployment; a rough estimate is worked out below.
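To make the cache pressure concrete, here is a back-of-the-envelope estimate in Python. The configuration (32 layers, 32 heads of dimension 128, fp16, 4096-token context) is purely illustrative and not tied to any specific released model.

    # Illustrative MHA KV-cache estimate (hypothetical config, fp16 values)
    num_layers, num_heads, head_dim = 32, 32, 128   # hidden_size = 4096
    seq_len, bytes_per_value = 4096, 2              # context length, fp16
    per_token = 2 * num_layers * num_heads * head_dim * bytes_per_value  # K and V
    total = per_token * seq_len
    print(f"{per_token / 1024:.0f} KiB per token, {total / 2**30:.0f} GiB for the full context")
    # -> 512 KiB per token, 2 GiB for a single 4096-token sequence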
2. Multi-Query Attention (MQA): Parameter Sharing for Maximum Efficiency
2.1 Motivation: Breaking the KV-Cache Memory Bottleneck
MHA's KV cache grows linearly with the number of heads and can easily exhaust GPU memory in long-sequence scenarios such as document understanding. MQA's core idea is to share the K and V parameters across heads, reducing redundant computation and storage, and to trade a modest loss of expressiveness for a gain in efficiency.
2.2 Core Mechanism: One Shared KV Set, Multi-Head Queries
MQA modifies MHA through its parameter-sharing strategy:
- Queries keep independent heads: Q is still produced by a multi-head linear projection, so each head retains its own query behavior.
- Keys and values are shared globally: all heads share a single set of K and V, so the K and V projection outputs have shape (batch_size, seq_len, head_dim) instead of MHA's (batch_size, seq_len, hidden_size).
- Broadcast expansion: the shared K and V are expanded to all heads with `unsqueeze` and `expand`, which enables the multi-head attention computation; a minimal snippet of this step follows the list.
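A minimal sketch of that broadcast step (tensor sizes chosen arbitrarily for illustration). Note that `expand` only creates a view with zero strides, so the shared K is not physically copied per head.

    import torch

    num_heads, seq_len, head_dim = 8, 10, 32
    k_shared = torch.randn(2, seq_len, head_dim)              # one K shared by all heads
    k = k_shared.unsqueeze(1).expand(-1, num_heads, -1, -1)   # [2, 8, 10, 32]
    print(k.shape, k.data_ptr() == k_shared.data_ptr())       # same storage: expand is a view, no copy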
2.3 Strengths and Limitations
- Strengths: the parameter count drops substantially (the K and V projection parameters shrink from about $2 \cdot \text{hidden\_size}^2$ to $2 \cdot \text{hidden\_size} \cdot \text{head\_dim}$), the KV cache is only 1/num_heads of MHA's, and inference speeds up noticeably.
- Limitations: globally sharing K and V reduces the distinctiveness between heads, which can hurt expressiveness, particularly the ability to capture fine-grained differences in long-sequence tasks.
3. Grouped-Query Attention (GQA): A Middle Ground Between Performance and Efficiency
3.1 Motivation: Balancing Expressiveness and Efficiency
MQA is efficient but gives up too much expressiveness, while MHA is expressive but expensive. GQA strikes a balance by sharing KV parameters per group: the query heads are divided into several groups, each group shares one set of K and V, cutting redundancy while preserving some diversity across heads.
3.2 Core Mechanism: Shared Within Groups, Independent Across Groups
- Grouping strategy: with num_heads heads in total and group_size heads per group, the number of groups is num_groups = num_heads / group_size.
- Group-wise K and V: the K and V projections output num_groups * head_dim features per position, reshaped to (batch_size, num_groups, seq_len, head_dim), so each group has its own independent K and V.
- Expansion for computation: each group's K and V are expanded to all heads within the group with `unsqueeze` and `expand`, enabling the grouped attention computation; an equivalent one-liner using `repeat_interleave` is sketched after this list.
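The same group-to-head expansion can be written more compactly with `repeat_interleave`. The sizes below are arbitrary and simply mirror the 8-head, 4-group example used later.

    import torch

    num_heads, group_size = 8, 2
    num_groups = num_heads // group_size                    # 4 KV groups
    k_grouped = torch.randn(2, num_groups, 10, 32)          # [batch, num_groups, seq_len, head_dim]
    k = k_grouped.repeat_interleave(group_size, dim=1)      # [2, 8, 10, 32]: each group serves group_size query heads
    print(k.shape)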
3.3 Strengths and Limitations
- Strengths: the KV parameter count and cache are num_groups / num_heads of MHA's (for example, 8 heads in 4 groups cut the cost to 50%), while the diversity preserved between groups gives better expressiveness than MQA.
- Limitations: performance depends on the choice of group_size; too small and GQA approaches MHA (low efficiency), too large and it approaches MQA (weak expressiveness), so it must be tuned per task.
4. Multi-Head Latent Attention (MLA): Fusing Low-Rank Compression with Position Decoupling
4.1 Motivation: Co-Optimizing Low-Rank Factorization and Positional Encoding
MHA, MQA, and GQA all stay within the paradigm of explicitly generating and caching K and V. MLA breaks out of it by compressing the KV representation through low-rank parameterization and by decoupling content from position information, gaining on both efficiency and quality; it is particularly well suited to long sequences and large-scale deployment.
4.2 Core Mechanism: Low-Rank Compression Plus Position Decoupling, Two Branches in Parallel
MLA's innovation lies in two key designs. First, low-rank KV compression: the hidden state is down-projected into a compact latent vector, from which K and V are reconstructed by up-projections, so only the low-dimensional latent needs to be cached at inference time. Second, decoupled RoPE: position information is carried by separate rotary-encoded query and key components that are concatenated with the content components, so positional encoding does not interfere with the low-rank compression. A simplified set of equations is given below.
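A simplified sketch of these two branches, written to match the projections in the Section 6.4 implementation (superscripts C and R denote the content and RoPE branches; the W matrices correspond to the code's projection layers, e.g. $W^{DKV}$ to down_proj_kv and $W^{UK}$ to up_proj_k):

$$
\begin{aligned}
c_t^{KV} &= W^{DKV} h_t, \qquad k_t^{C} = W^{UK} c_t^{KV}, \qquad v_t^{C} = W^{UV} c_t^{KV}, \\
c_t^{Q} &= W^{DQ} h_t, \qquad q_t^{C} = W^{UQ} c_t^{Q}, \\
q_t^{R} &= \mathrm{RoPE}\!\left(W^{QR} c_t^{Q}\right), \qquad k_t^{R} = \mathrm{RoPE}\!\left(W^{KR} h_t\right), \\
q_t &= [\,q_t^{C};\, q_t^{R}\,], \qquad k_t = [\,k_t^{C};\, k_t^{R}\,].
\end{aligned}
$$

At inference time only the low-dimensional latent $c_t^{KV}$ (plus the shared $k_t^{R}$) needs to be cached, which is where the KV-cache savings come from.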
4.3 Strengths and Limitations
- Strengths: both the parameter count (about 42% of MHA's) and the KV cache (compressed to roughly 1/5 to 1/10 of GQA's) drop sharply, while the low-rank factorization retains the key features, so expressiveness stays close to MHA.
- Limitations: the low-rank projections and position decoupling add model complexity, making MLA harder to implement than the other three mechanisms, and the projection matrices must be merged carefully (the so-called "absorption" trick, shown below) to avoid redundant computation.
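The "absorption" refers to the following identity, sketched here in the notation of Section 4.2: because the content-branch score is a bilinear form in the two latents, $W^{UK}$ can be folded into the query-side projection, so $k_t^{C}$ never has to be materialized during inference.

$$
(q_t^{C})^{\top} k_s^{C}
= \big(W^{UQ} c_t^{Q}\big)^{\top} \big(W^{UK} c_s^{KV}\big)
= (c_t^{Q})^{\top} \big((W^{UQ})^{\top} W^{UK}\big)\, c_s^{KV}
$$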
5. Theoretical Comparison of the Four Mechanisms: Weighing Parameters Against Capability
| Mechanism | Core idea | Parameters (relative) | KV cache (relative) | Expressiveness | Typical use cases |
| --- | --- | --- | --- | --- | --- |
| MHA | Parallel multi-head attention | 1.0 | 1.0 | Strongest | Pre-training, performance-critical workloads |
| MQA | Globally shared KV | 0.56 | 1/num_heads | Weaker | Edge deployment, high-concurrency inference |
| GQA | Group-wise shared KV | 0.75 | num_groups/num_heads | Strong | General-purpose LLMs, balanced requirements |
| MLA | Low-rank compression + position decoupling | 0.42 | ~0.1 | Strong | Long sequences, large-scale deployment |
6. Code Implementations
The code below is adapted from https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA (will be removed on request).
6.1 Multi-Head Attention (MHA)
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head
        # Linear projections for Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Project to Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, hidden_size]
        value = self.value(hidden_state)  # [batch_size, seq_len, hidden_size]
        # Split into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply the mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        # Softmax over the key positions: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    mha = MultiHeadAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions
    output = mha(hidden_state, attention_mask)
    print("MHA output shape:", output.shape)  # torch.Size([2, 10, 256])
6.2 Multi-Query Attention (MQA)
import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Q keeps independent heads; K and V are shared by all heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.head_dim)
        self.value = nn.Linear(hidden_size, self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Project to Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]
        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Broadcast the shared K and V to all heads: [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply the mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    mqa = MultiQueryAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0
    output = mqa(hidden_state, attention_mask)
    print("MQA output shape:", output.shape)  # torch.Size([2, 10, 256])
6.3 Grouped-Query Attention (GQA)
import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads
        # Q keeps independent heads; K and V are shared within each group
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Project to Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, group_num * head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, group_num * head_dim]
        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Split K and V into groups, then expand each group to its heads
        # [batch_size, group_num, seq_len, head_dim]
        key = key.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)
        # [batch_size, group_num, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        value = value.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply the mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2  # 2 heads per group, 4 groups in total
    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0
    output = gqa(hidden_state, attention_mask)
    print("GQA output shape:", output.shape)  # torch.Size([2, 10, 256])
6.4 Multi-Head Latent Attention (MLA)
import torch
import torch.nn as nn
import math


class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # Standard RoPE frequencies: theta_i = base^(-2i / head_dim)
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        # cos/sin tables, each frequency repeated for its (even, odd) dimension pair
        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, q):
        bs, seq_len = q.shape[0], q.shape[2]
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        cos_pos = cos_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        # RoPE rotation: pair up even/odd dimensions and rotate
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()
        return q * cos_pos + q2 * sin_pos


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, dropout_prob=0.0):
        super(MultiHeadLatentAttention, self).__init__()
        self.d_model = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads
        # Down-projections (compression into low-rank latents)
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)
        # Up-projections (reconstruction from the latents)
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)
        # Decoupled Q/K projections for the RoPE branch
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)
        # RoPE positional encodings
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)
        # Output layers
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(dropout_prob)

    def forward(self, h, mask=None):
        bs, seq_len, _ = h.size()
        # Low-rank compression and reconstruction
        c_t_kv = self.down_proj_kv(h)   # [bs, seq_len, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [bs, seq_len, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)  # [bs, seq_len, up_dim]
        c_t_q = self.down_proj_q(h)     # [bs, seq_len, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [bs, seq_len, up_dim]
        # Decoupled RoPE branch for Q
        q_t_r = self.proj_qr(c_t_q)     # [bs, seq_len, rope_head_dim * num_heads]
        q_t_r = q_t_r.view(bs, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)  # [bs, num_heads, seq_len, rope_head_dim]
        q_t_r = self.rope_q(q_t_r)      # apply RoPE to the query branch
        # Decoupled RoPE branch for K (one shared rotary key for all heads)
        k_t_r = self.proj_kr(h)         # [bs, seq_len, rope_head_dim]
        k_t_r = k_t_r.unsqueeze(1)      # [bs, 1, seq_len, rope_head_dim]
        k_t_r = self.rope_k(k_t_r)      # apply RoPE to the key branch
        # Attention: concatenate content and RoPE parts of Q and K
        q_t_c = q_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        q = torch.cat([q_t_c, q_t_r], dim=-1)
        k_t_c = k_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        k_t_r = k_t_r.expand(bs, self.num_heads, seq_len, -1)                # broadcast the shared rotary key to all heads
        k = torch.cat([k_t_c, k_t_r], dim=-1)
        # Scores: [bs, num_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))
        if mask is not None:
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Reshape V and compute context vectors: [bs, num_heads, seq_len, v_head_dim]
        v_t_c = v_t_c.view(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        context = torch.matmul(attn_weights, v_t_c)
        # Merge heads: [bs, seq_len, num_heads * v_head_dim]
        context = context.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        # Output projection: [bs, seq_len, d_model]
        output = self.fc(context)
        output = self.res_dropout(output)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    h = torch.randn(batch_size, seq_len, hidden_size)
    mla = MultiHeadLatentAttention(hidden_size=hidden_size)
    mask = torch.ones(batch_size, seq_len)
    mask[:, 5:] = 0
    output = mla(h, mask)
    print("MLA output shape:", output.shape)  # torch.Size([2, 10, 256])
7. Summary and Recommendations
The evolution from MHA to MLA is essentially the art of trading expressiveness against computational efficiency: MHA laid the multi-head parallel foundation, MQA and GQA optimize efficiency through parameter sharing, and MLA achieves a qualitative leap through low-rank factorization and position decoupling.
The four mechanisms each have their strengths and weaknesses, and the choice should depend on the scenario:
- MHA: for performance-critical, resource-rich settings such as pre-training.
- MQA: for resource-constrained, latency-sensitive settings such as edge deployment.
- GQA: the default choice in most cases, balancing performance and efficiency; well suited to general-purpose large models.
- MLA: for long-sequence tasks and large-scale deployment, where it excels when GPU memory is tight.
As large models keep growing in parameter count and context length, attention mechanisms will continue to be optimized. Developers should pick the mechanism that fits their actual needs and follow the latest research to keep improving model performance and efficiency.
References
- 宋志學(xué), 《手撕大模型 Attention:MLA、MHA、MQA 與 GQA(含實(shí)現(xiàn)代碼)》, https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA, 2025-05-20, WeChat official account.
- 蘇劍林, 《Transformer 升級(jí)之路:多頭潛在注意力機(jī)制(MLA)究竟好在哪里?》, https://mp.weixin.qq.com/s/KdOjWF4n5gNtQxKKvkG5Mw, 2025-05-22, WeChat official account.
- 姜富春, 《DeepSeek 技術(shù)解讀 1:徹底理解 MLA》, https://mp.weixin.qq.com/s/yL_Z8zcAfWDcviZwApdL_w, 2025-01-15, WeChat official account.
- 算法狗, 《DeepSeek MLA:高效推理的省錢之道,全流程剖析》, https://mp.weixin.qq.com/s/yNxjgQMl2LKzpGOoCWRRcw, 2025-02-19, WeChat official account.
- 羽說(shuō) AI 研究圈, 《從 MHA→MQA→GQA→MLA》, https://mp.weixin.qq.com/s/S9dfOCrWeru6zGjOjchV7Q, 2025-02-12, WeChat official account.
Reposted from 鴻煊的學(xué)習(xí)筆記; author: 乘風(fēng)破浪jxj.
