
LLM Attention Mechanisms: MHA, GQA, MQA, and MLA in Theory and Practice

Published 2025-07-17 14:03
The attention mechanism is the soul of the Transformer architecture and the key lever for balancing performance and efficiency in large language models. From the original multi-head attention (MHA) to the more recent multi-head latent attention (MLA), researchers have kept refining how queries, keys, and values interact, searching for better trade-offs between expressiveness and computational cost. This article walks through the theory behind four mainstream attention variants, MHA, MQA, GQA, and MLA, analyzing their motivations, core mechanisms, and code implementations.

1. Multi-Head Attention (MHA): The Foundation of Parallel Feature Capture

1.1 Motivation: Breaking the Expressiveness Bottleneck of Single-Head Attention

Before the Transformer, conventional attention mechanisms (such as Bahdanau attention) modeled sequence dependencies with a single set of queries, keys, and values, which made it hard to capture several kinds of patterns at once (for example, syntactic structure and semantic association). MHA's core innovation is to project the input into multiple subspaces and compute attention in parallel, so the model can attend to different positions and different feature dimensions simultaneously, strengthening its ability to model complex patterns.


1.2 Core Mechanism: Project, Split, Attend, Merge

MHA's computation can be summarized in four steps: linear projection, head splitting, attention computation, and merged projection:

  • Linear projection: the input sequence is multiplied by three learnable matrices to produce Query (Q), Key (K), and Value (V), each of shape (batch_size, seq_len, hidden_size).
  • Head splitting: Q, K, and V are split into num_heads heads with head_dim = hidden_size / num_heads per head and reshaped to (batch_size, num_heads, seq_len, head_dim).
  • Scaled dot-product attention: each head computes its attention weights independently as softmax(QK^T / sqrt(d_k))V, where the scaling factor sqrt(d_k) keeps the dot products from growing so large that the softmax saturates and gradients vanish.
  • Merged projection: the outputs of all heads are concatenated and mapped back to the original hidden_size by a linear layer. The standard formulation is summarized right after this list.
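For reference, the standard formulation of these steps (as given in the original Transformer paper) is:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\]

\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^{O}, \qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V})
\]

Here h = num_heads, d_k = head_dim, and the projections W_i^Q, W_i^K, W_i^V, W^O correspond to the query, key, value, and out_projection layers of the implementation in Section 6.1.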

1.3 Strengths and Limitations

  • Strengths: multi-head parallelism lets the model capture features at multiple scales (for example, local syntax and global semantics), and it is a core source of large models' expressive power.
  • Limitations: parameters and compute grow with the number of heads (the Q, K, and V projection layers alone contribute roughly 3 * hidden_size^2 parameters), and inference must cache K and V for every head, so each token adds about 2 * num_heads * head_dim = 2 * hidden_size cached values per layer. This KV cache quickly becomes the bottleneck for long sequences and large-scale deployment; a rough cost estimate follows below.
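To make the cost concrete, here is a minimal back-of-the-envelope sketch; the model sizes are purely illustrative assumptions, not values from the article:

# Rough per-token KV-cache estimate for MHA; all sizes below are hypothetical.
hidden_size = 4096          # model width
num_layers = 32             # number of Transformer layers
bytes_per_value = 2         # fp16/bf16 storage

# Each layer caches K and V: 2 * hidden_size values per token.
kv_bytes_per_token = 2 * hidden_size * num_layers * bytes_per_value

context_len = 32_000        # tokens kept in the context
print(f"KV cache per token: {kv_bytes_per_token / 1024:.0f} KiB")
print(f"KV cache for {context_len} tokens: {kv_bytes_per_token * context_len / 2**30:.1f} GiB")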

2. Multi-Query Attention (MQA): Parameter Sharing for Extreme Efficiency

2.1 Motivation: Breaking the KV-Cache Memory Bottleneck

MHA's KV cache grows linearly with the number of heads, which easily exhausts GPU memory in long-sequence settings such as document understanding. MQA's core idea is to share the K and V parameters across heads, removing redundant computation and storage and trading a small amount of expressiveness for a large efficiency gain.


2.2 Core Mechanism: One Shared Set of K/V, Many Query Heads

MQA changes MHA's parameter-sharing strategy as follows:

  • Queries stay multi-headed: Q is still produced by a full multi-head projection, so each head keeps its own query behavior.
  • Keys and values are shared globally: all heads share a single K and a single V, so the K and V projections output (batch_size, seq_len, head_dim) instead of MHA's (batch_size, seq_len, hidden_size).
  • Broadcast expansion: the shared K and V are expanded to all heads with unsqueeze and expand, so the multi-head attention computation proceeds unchanged, as the sketch after this list shows.
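A minimal standalone sketch of that broadcast step with made-up sizes (the complete module is in Section 6.2):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 32

# One shared key for all heads, as produced by MQA's key projection.
key = torch.randn(batch_size, seq_len, head_dim)

# Broadcast it to every head; expand creates a view, not a copy.
key_per_head = key.unsqueeze(1).expand(-1, num_heads, -1, -1)
print(key_per_head.shape)  # torch.Size([2, 8, 10, 32])

# Only the original [batch_size, seq_len, head_dim] tensor has to be cached,
# i.e. 1/num_heads of what MHA would cache for keys.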

2.3 Strengths and Limitations

  • Strengths: the parameter count drops sharply (the K and V projections shrink from roughly 2 * hidden_size^2 to 2 * hidden_size * head_dim parameters), the KV cache falls to 1/num_heads of MHA's, and inference speeds up noticeably.
  • Limitations: globally sharing K and V reduces the differentiation between heads, which can hurt expressiveness, especially the ability to capture fine-grained distinctions in long-sequence tasks.

3. Grouped-Query Attention (GQA): A Middle Ground Between Performance and Efficiency

3.1 Motivation: An Intermediate Point Between Expressiveness and Efficiency

MQA is efficient but gives up too much expressiveness, while MHA is expressive but expensive. GQA strikes a balance by sharing K and V within groups: the query heads are divided into several groups, each group sharing one set of K and V, which removes redundancy while preserving some inter-head diversity.


3.2 Core Mechanism: Shared Within a Group, Independent Across Groups

  • Grouping: with num_heads query heads and group_size heads per group, the number of groups is num_groups = num_heads / group_size.
  • K and V per group: the K and V projections output (batch_size, seq_len, num_groups, head_dim), one independent K/V set per group.
  • Expansion: unsqueeze and expand broadcast each group's K and V to all heads inside that group, as shown in the sketch after this list.
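A minimal standalone sketch of the group-to-head expansion with made-up sizes; it uses repeat_interleave, an equivalent alternative to the unsqueeze/expand/view sequence in Section 6.3:

import torch

batch_size, seq_len, head_dim = 2, 10, 32
num_heads, group_size = 8, 2
num_groups = num_heads // group_size  # 4 groups

# Grouped keys: one key tensor per group.
key = torch.randn(batch_size, num_groups, seq_len, head_dim)

# Repeat each group's key for every head inside that group.
key_per_head = key.repeat_interleave(group_size, dim=1)
print(key_per_head.shape)  # torch.Size([2, 8, 10, 32])

# The cache stores only the grouped tensor: num_groups/num_heads of MHA's key cache.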

3.3 Strengths and Limitations

  • Strengths: the K/V projection parameters and the KV cache shrink to num_groups / num_heads of MHA's (for example, 8 heads in 4 groups cuts them to 50%), while the remaining inter-group diversity keeps expressiveness above MQA's.
  • Limitations: performance depends on the choice of group_size; too small and GQA behaves like MHA (little efficiency gain), too large and it behaves like MQA (weaker expressiveness), so it needs per-task tuning.

4. Multi-Head Latent Attention (MLA): Fusing Low-Rank Compression with Position Decoupling

4.1 Motivation: Co-Designing Low-Rank Factorization and Positional Encoding

MHA, MQA, and GQA all stay within the paradigm of explicitly generating and caching K and V. MLA breaks out of it: it compresses the KV representation through low-rank parameterization and decouples content from position information, improving both efficiency and quality, which makes it especially attractive for long sequences and large-scale deployment.


4.2 Core Mechanism: Low-Rank Compression and Position Decoupling, Two Parallel Paths

MLA's innovation lies in two key designs:

  • Low-rank KV compression: the hidden state is down-projected to a compact latent vector (down_proj_kv in Section 6.4), from which K and V are recovered by up-projections (up_proj_k / up_proj_v); the query goes through an analogous down/up pair. At inference only this small latent (plus one shared positional key, see the next point) needs to be cached instead of full per-head K and V.
  • Decoupled RoPE: rotary position information travels through a separate, narrow set of dimensions (proj_qr for queries and a single shared proj_kr for keys) that are rotated with RoPE and concatenated onto the content part of Q and K. Keeping position out of the compressed path is what allows the up-projection matrices to be merged ("absorbed") into neighboring projections at inference. A simplified caching sketch follows this list.
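To make the caching benefit concrete, here is a deliberately simplified sketch of incremental decoding with a compressed latent; it omits RoPE, multi-head splitting, and the absorption trick, and its sizes are illustrative rather than DeepSeek's:

import torch
import torch.nn as nn

hidden_size, down_dim, up_dim = 256, 64, 128     # illustrative widths

down_proj_kv = nn.Linear(hidden_size, down_dim)  # compression
up_proj_k = nn.Linear(down_dim, up_dim)          # decompression at attention time
up_proj_v = nn.Linear(down_dim, up_dim)

latent_cache = []  # per-token cache: only the down_dim-wide latent is stored


def decode_step(h_t):
    """h_t: [batch, 1, hidden_size] hidden state of the newly generated token."""
    c_t = down_proj_kv(h_t)                  # [batch, 1, down_dim]
    latent_cache.append(c_t)
    c_all = torch.cat(latent_cache, dim=1)   # [batch, t, down_dim]
    # K and V for the whole prefix are re-materialized from the small cache.
    return up_proj_k(c_all), up_proj_v(c_all)


k, v = decode_step(torch.randn(2, 1, hidden_size))
print(k.shape, v.shape)  # torch.Size([2, 1, 128]) torch.Size([2, 1, 128])
# Cached per token: down_dim = 64 values, versus 2 * hidden_size = 512 for MHA.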

4.3 Strengths and Limitations

  • Strengths: both the parameter count (roughly 42% of MHA's) and the KV cache (about 1/5 to 1/10 of GQA's) drop sharply, while the low-rank factorization preserves the important features, keeping expressiveness close to MHA's.
  • Limitations: the low-rank projections and position decoupling add complexity, making MLA harder to implement than the other three mechanisms, and the projection matrices must be merged carefully (the "absorption" trick) to avoid redundant computation at inference; the identity behind that trick is sketched below.
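For intuition about that absorption, the trick is just associativity of matrix products. Writing the content part of a query as q = W^{UQ} c^{Q} and of a key as k = W^{UK} c^{KV} (the up-projections and latent vectors of Section 6.4), the attention logit is

\[
q^{\top} k = \bigl(W^{UQ} c^{Q}\bigr)^{\top}\bigl(W^{UK} c^{KV}\bigr)
           = (c^{Q})^{\top}\bigl[(W^{UQ})^{\top} W^{UK}\bigr]\, c^{KV},
\]

so the product (W^{UQ})^{\top} W^{UK} can be precomputed and applied to the cached latent c^{KV} directly, without ever materializing per-head keys; the value up-projection can likewise be folded into the output projection. This only works because the RoPE-rotated dimensions live in a separate branch, which is exactly why MLA decouples them.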

5. Comparing the Four Mechanisms: Trade-Offs from Parameters to Capability

| Mechanism | Core innovation | Parameters (relative) | KV cache (relative) | Expressiveness | Typical scenarios |
| --- | --- | --- | --- | --- | --- |
| MHA | Multi-head parallelism | 1.0 | 1.0 | Strongest | Pretraining, performance-critical workloads |
| MQA | Globally shared KV | 0.56 | 1/num_heads | Weaker | Edge deployment, high-concurrency inference |
| GQA | Group-shared KV | 0.75 | num_groups/num_heads | Fairly strong | General-purpose LLMs, balanced requirements |
| MLA | Low-rank compression + position decoupling | 0.42 | ~0.1 | Strong | Long sequences, large-scale deployment |
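The relative parameter figures in the table depend on the configuration the original article assumed, but the KV-cache ratios follow directly from the definitions. A small sketch with hypothetical sizes (including an assumed MLA latent width) that reproduces them per layer:

# Per-token, per-layer KV-cache sizes; all configuration values are illustrative.
hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads
num_groups = 8                             # GQA groups
mla_latent_dim, mla_rope_dim = 512, 64     # MLA compressed-KV width and decoupled-RoPE width

mha_cache = 2 * num_heads * head_dim       # full K and V for every head
mqa_cache = 2 * head_dim                   # one shared K and V
gqa_cache = 2 * num_groups * head_dim      # one K and V per group
mla_cache = mla_latent_dim + mla_rope_dim  # compressed latent plus shared RoPE key

for name, cache in [("MHA", mha_cache), ("MQA", mqa_cache),
                    ("GQA", gqa_cache), ("MLA", mla_cache)]:
    print(f"{name}: {cache} values per token, {cache / mha_cache:.3f} of MHA")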

6. Implementations in Code

The code below comes from https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA (credit to the original author; it will be taken down on request).

6.1 Multi-Head Attention (MHA)

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Compute Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, hidden_size]
        value = self.value(hidden_state)  # [batch_size, seq_len, hidden_size]

        # Split into heads
        # [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention scores
        # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the attention mask (0 = masked position)
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, value)

        # Merge heads  # [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection  # [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8

    mha = MultiHeadAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions

    output = mha(hidden_state, attention_mask)
    print("MHA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.2 Multi-Query Attention (MQA)

import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q keeps independent heads; K and V are shared by all heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.head_dim)
        self.value = nn.Linear(hidden_size, self.head_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Compute Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]

        # Split Q into heads
        # [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Broadcast the shared K and V to all heads
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        # Compute attention scores  # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the attention mask (0 = masked position)
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        # [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)

        # Merge heads  # [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection  # [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8

    mqa = MultiQueryAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions

    output = mqa(hidden_state, attention_mask)
    print("MQA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.3 Grouped-Query Attention (GQA)

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads

        # Q keeps independent heads; K and V are shared within each group
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Compute Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, group_num * head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, group_num * head_dim]

        # Split Q into heads
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Split K and V into groups, then expand each group to its heads
        # [batch_size, group_num, seq_len, head_dim]
        key = key.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)

        # [batch_size, group_num, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        value = value.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)

        # Compute attention scores
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the attention mask (0 = masked position)
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, value)

        # Merge heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2  # 2 heads per group, 4 groups in total

    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions

    output = gqa(hidden_state, attention_mask)
    print("GQA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.4 Multi-Head Latent Attention (MLA)

import torch
import torch.nn as nn
import math


class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # Per-dimension rotation frequencies and per-position angles
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)

        # cos/sin tables, repeated so each (even, odd) dimension pair shares one angle
        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)

        return cos_pos, sin_pos

    def forward(self, q):
        seq_len = q.shape[2]
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]

        cos_pos = cos_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]

        # RoPE rotation: build (-x2, x1) pairs, then combine with the cos/sin tables
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()

        return q * cos_pos + q2 * sin_pos


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, dropout_prob=0.0):
        super(MultiHeadLatentAttention, self).__init__()
        self.d_model = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads

        # Down-projections (low-rank compression of the KV path and the Q path)
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)

        # Up-projections (recover K, V and the content part of Q from the latents)
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)

        # Decoupled RoPE projections for Q (per head) and K (shared across heads)
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)

        # RoPE positional encodings for the decoupled branches
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)

        # Output layers
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(dropout_prob)


    def forward(self, h, mask=None):
        bs, seq_len, _ = h.size()

        # Low-rank compression and recovery
        c_t_kv = self.down_proj_kv(h)   # [bs, seq_len, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [bs, seq_len, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)  # [bs, seq_len, up_dim]
        c_t_q = self.down_proj_q(h)     # [bs, seq_len, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [bs, seq_len, up_dim]

        # Decoupled RoPE branch for Q
        # [bs, seq_len, rope_head_dim * num_heads]
        q_t_r = self.proj_qr(c_t_q)
        # [bs, num_heads, seq_len, rope_head_dim]
        q_t_r = q_t_r.view(bs, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)
        q_t_r = self.rope_q(q_t_r)  # apply RoPE to the query rope branch

        # Decoupled RoPE branch for K (a single branch shared by all heads)
        k_t_r = self.proj_kr(h)     # [bs, seq_len, rope_head_dim]
        k_t_r = k_t_r.unsqueeze(1)  # [bs, 1, seq_len, rope_head_dim]
        k_t_r = self.rope_k(k_t_r)  # apply RoPE to the key rope branch

        # Attention: concatenate the content part and the RoPE part of Q and K
        # [bs, num_heads, seq_len, up_dim/num_heads]
        q_t_c = q_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]
        q = torch.cat([q_t_c, q_t_r], dim=-1)

        # [bs, num_heads, seq_len, up_dim/num_heads]
        k_t_c = k_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # [bs, num_heads, seq_len, rope_head_dim]
        k_t_r = k_t_r.expand(bs, self.num_heads, seq_len, -1)
        # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]
        k = torch.cat([k_t_c, k_t_r], dim=-1)

        # [bs, num_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-1, -2))
        # scaling term used by the original source: sum of sqrt of content and RoPE head dims
        scores = scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))

        if mask is not None:
            # [bs, num_heads, seq_len, seq_len]
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))

        # [bs, num_heads, seq_len, seq_len]
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Reshape V into heads  # [bs, num_heads, seq_len, v_head_dim]
        v_t_c = v_t_c.view(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)

        # Weighted sum of values  # [bs, num_heads, seq_len, v_head_dim]
        context = torch.matmul(attn_weights, v_t_c)

        # Merge heads  # [bs, seq_len, num_heads * v_head_dim]
        context = context.transpose(1, 2).contiguous().view(bs, seq_len, -1)

        # Output projection  # [bs, seq_len, d_model]
        output = self.fc(context)
        output = self.res_dropout(output)

        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256

    h = torch.randn(batch_size, seq_len, hidden_size)
    mla = MultiHeadLatentAttention(hidden_size=hidden_size)

    mask = torch.ones(batch_size, seq_len)
    mask[:, 5:] = 0  # mask out the last 5 positions

    output = mla(h, mask)
    print("MLA output shape:", output.shape)  # torch.Size([2, 10, 256])

7. Summary and Recommendations

The evolution from MHA to MLA is essentially the art of trading expressiveness against computational efficiency: MHA laid the multi-head foundation, MQA and GQA improved efficiency through parameter sharing, and MLA achieved a qualitative leap through low-rank factorization and position decoupling.

Each of the four mechanisms has its place, and the choice should follow the deployment scenario:

  • MHA: for performance-critical, resource-rich settings such as pretraining.
  • MQA: for resource-constrained, latency-sensitive settings such as edge deployment.
  • GQA: the default choice in most cases, balancing performance and efficiency for general-purpose LLMs.
  • MLA: for long-sequence tasks and large-scale deployment, where it shines under tight memory budgets.

As models grow larger and contexts longer, attention mechanisms will keep being optimized. Developers should choose the mechanism that matches their actual requirements and keep an eye on new research to continue improving both quality and efficiency.

References

  1. 宋志學(xué), 《手撕大模型 Attention:MLA、MHA、MQA 與 GQA(含實(shí)現(xiàn)代碼)》, https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA, 2025-05-20, WeChat official account
  2. 蘇劍林, 《Transformer 升級(jí)之路:多頭潛在注意力機(jī)制(MLA)究竟好在哪里?》, https://mp.weixin.qq.com/s/KdOjWF4n5gNtQxKKvkG5Mw, 2025-05-22, WeChat official account
  3. 姜富春, 《DeepSeek 技術(shù)解讀 1:徹底理解 MLA》, https://mp.weixin.qq.com/s/yL_Z8zcAfWDcviZwApdL_w, 2025-01-15, WeChat official account
  4. 算法狗, 《DeepSeek MLA:高效推理的省錢之道,全流程剖析》, https://mp.weixin.qq.com/s/yNxjgQMl2LKzpGOoCWRRcw, 2025-02-19, WeChat official account
  5. 羽說(shuō) AI 研究圈, 《從 MHA→MQA→GQA→MLA》, https://mp.weixin.qq.com/s/S9dfOCrWeru6zGjOjchV7Q, 2025-02-12, WeChat official account


This article is reposted from 鴻煊的學(xué)習(xí)筆記; author: 乘風(fēng)破浪jxj.

Copyright belongs to the author. Please credit the source when reposting; unauthorized reuse may incur legal liability.