成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

一文詳解MHA、GQA、MQA原理 原創

發布于 2024-11-14 15:40
瀏覽
0收藏

前言

本文回顧一下MHA、GQA、MQA,詳細解讀下MHA、GQA、MQA這三種常見注意力機制的原理。

一文詳解MHA、GQA、MQA原理-AI.x社區

圖1 MHA、GQA、MQA一覽

self-attention

一文詳解MHA、GQA、MQA原理-AI.x社區

self-attention

在自注意力機制中,輸入通常是一個統一的輸入矩陣,而這個矩陣后續會通過乘以不同的權重矩陣來轉換成三個不同的向量集合:查詢向量Q、鍵向量K和值向量V。這三組向量是通過線性變換方式生成:

1.查詢向量 (Q): Q=XWQ

2.鍵向量 (K): K=XWK

3.值向量 (V): V=XWV

W,WK和WV可學習的權重矩陣,分別對應于查詢、鍵和值。這些矩陣的維度取決于模型的設計,通常它們的輸出維度(列數) 是預先定義的,以滿足特定的模型架構要求。 在Transformer模型中,使用不同的權重矩陣W,WK和WV來分別生成查詢向量Q、鍵向量K和值向量V的目的是為了允許模型在不同的表示空間中學習和抽取特征。這樣做增加了模型的靈活性和表達能力,允許模型分別優化用于匹配(Q 和K)和用于輸出信息合成(V)的表示。

在自注意力和多頭注意力機制中,使用

一文詳解MHA、GQA、MQA原理-AI.x社區

作為縮放因子進行縮放操作是為了防止在計算點積時由于維度較高導致的數值穩定性問題。這里的dk是鍵向量的維度。如果不進行縮放,當dk較大時,點積的結果可能會變得非常大,這會導致在應用softmax函數時產生的梯度非常小。因為softmax函數是通過指數函數計算的,大的輸入值會使得部分輸出接近于1,而其他接近于0,從而導致梯度消失,這會在反向傳播過程中造成梯度非常小,使得學習變得非常緩慢。

通過點積結果除以

一文詳解MHA、GQA、MQA原理-AI.x社區

 ,可以調整這些值的范圍,使得它們不會太大。這樣,softmax的輸入在一個合適的范圍內,有助于避免極端的指數運算結果,從而保持數值穩定性和更有效的梯度流。這個操作確保了即使在dk很大的情況下, 注意力機制也能穩定并有效地學習。

代碼實現

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, seq_length):
        super(SelfAttention, self).__init__()
        self.input_size = seq_length
        # 定義三個權重矩陣:Wq、Wk、Wv
        self.Wq = nn.Linear(seq_length, seq_length)  # 線性變換
        self.Wk = nn.Linear(seq_length, seq_length)
        self.Wv = nn.Linear(seq_length, seq_length)

    def forward(self, input):
        # 計算Q,K,V 三個矩陣
        q = self.Wq(input)
        k = self.Wk(input)
        v = self.Wv(input)

        # 計算QK^T,即向量之間的相關度
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))
        # 計算向量權重,softmax歸一化
        attention_weight = F.softmax(attention_scores, dim=-1)
        # 計算輸出
        output = torch.matmul(attention_weight, v)
        return output


x = torch.randn(2, 3, 4)
Self_Attention = SelfAttention(4)  # 傳入輸入向量的維度
output = Self_Attention(x)
print(output.shape)

MHA(多頭注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區

Transformer 編碼器塊內的縮放點積注意力機制和多頭注意力機制

一文詳解MHA、GQA、MQA原理-AI.x社區

MHA計算過程

一文詳解MHA、GQA、MQA原理-AI.x社區

代碼實現

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.wv = nn.Linear(embed_dim, embed_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def mh_split(self, hidden):
        batch_size = hidden.shape[0]
        x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)

        # 注意力計算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 拼接多頭
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

        # 線性變換
        output = self.wo(output)

        return output

x = torch.rand(2, 3, 36)
print(x)
output = MultiHeadAttention(36, 6)
y = output(x)
print(y.shape)

MHA 能夠理解輸入不同部分之間的關系。然而,這種復雜性是有代價的——對內存帶寬的需求很大,尤其是在解碼器推理期間。主要問題的關鍵在于內存開銷。在自回歸模型中,每個解碼步驟都需要加載解碼器權重以及所有注意鍵和值。這個過程不僅計算量大,而且內存帶寬也大。隨著模型規模的擴大,這種開銷也會增加,使得擴展變得越來越艱巨。

因此,多查詢注意 (MQA) 應運而生,成為緩解這一瓶頸的解決方案。其理念簡單而有效:使用多個查詢頭,但只使用一個鍵和值頭。這種方法顯著減少了內存負載,提高了推理速度。

MQA(多查詢注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區

圖2 MHA和MQA的差別

MQA是MHA的一種變體,也是用于自回歸解碼的一種注意力機制。,圖1、圖2很形象的描繪了MHA和MQA的對比,與MHA 不同的是,MQA 讓所有的Head之間共享同樣的一份 K 和 V 矩陣(意味K和V的計算唯一),只讓 Q 保留了原始多頭的性質(每個Head存在不同的轉換),從而大大減少 K 和 V 矩陣的參數量以及KV Cache的顯存占用,以此來達到提升推理速度,但是會帶來精度上的損失。MQA被大量應用于LLM中,如ChatGLM2。

一文詳解MHA、GQA、MQA原理-AI.x社區

左 - 多頭注意力,中 - 多查詢注意力,右 - 將現有的 MHA 檢查點轉換為 MQA

如何將現有的預訓練多頭注意力模型轉換為多查詢注意力模型 (MQA)?從現有的多頭模型創建多查詢注意力模型涉及兩個步驟:模型結構的轉換和隨后的預訓練。

  • 模型結構的轉換:此步驟將多頭模型的結構轉換為多查詢模型。它是通過將原始模型的多個頭的鍵和值的投影矩陣(線性層)合并(均值池化)為鍵和值的單個投影矩陣來實現的。這種均值池化方法被發現比選擇現有鍵和值頭之一或從頭開始初始化新的鍵和值頭更有效。生成的結構具有合并的鍵和值投影,這是多查詢模型的特征。
  • 對轉換后的模型進行預訓練:結構轉換后,模型將接受額外的訓練。此訓練不像原始模型訓練那樣廣泛;它只是原始模型訓練步驟的一小部分(表示為 α)。此預訓練階段的目的是讓模型根據其新的簡化注意力機制調整和優化其性能。訓練遵循與原始相同的方法,確保學習動態的一致性。

代碼實現

import torch
import torch.nn as nn


class MultiQuerySelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiQuerySelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.wq = nn.Linear(embed_dim, embed_dim)

        # MHA
        # self.wk = nn.Linear(embed_dim, embed_dim)
        # self.wv = nn.Linear(embed_dim, embed_dim)

        # MQA
        self.wk = nn.Linear(embed_dim, self.head_dim)
        self.wv = nn.Linear(embed_dim, self.head_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def q_h_split(self, hidden, head_num=None):
        batch_size, seq_len = hidden.size()[:2]
        # q拆分多頭
        if head_num == None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 這是MQA: 需要拆分k和v,這里面的head_num =1 的
            # 最終返回維度(batch_size, 1, seq_len, head_dim)
            return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        # 這是MHA的
        # q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 這是MQA的
        q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)

        # 注意力計算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 多頭合并
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        # 線性變換
        output = self.wo(output)
        return output


x = torch.rand(3, 12, 512)
atten = MultiQuerySelfAttention(512, 8)
y = atten(x)
print(y.shape)

GQA(分組查詢注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區

雖然MQA方式大幅減小了參數數量,但是,帶來推理加速的同時會造成模型性能損失,且在訓練過程使得模型變得不穩定(復雜度的降低可能會導致質量下降和訓練不穩定),因此在此基礎上提出了GQA,它將Query進行分組,每個組內共享一組Key、Value。(GQA在LLaMA-2 和 Mistral7B得到應用)

GQA 的數學原理

分組:在 GQA 中,傳統多頭模型中的查詢頭 (Q) 被分成 G 組。每組分配一個鍵 (K) 和值 (V) 頭。此配置表示為 GQA-G,其中 G 表示組數。

GQA 的特殊情況

  • GQA-1 = MQA:只有一個組(G = 1),GQA 等同于 MQA,因為所有查詢頭只有一個鍵和值頭。
  • GQA-H = MHA:當組數等于頭數(G = H)時,GQA 退化為 MHA,每個查詢頭都有其唯一的鍵和值頭。

對每個組中原始頭部的鍵和值投影矩陣進行均值池化,以將MHA模型轉換為 GQA 模型。此技術對組中每個頭部的投影矩陣進行平均,從而為該組生成單個鍵和值投影。

通過利用 GQA,該模型在 MHA 質量和 MQA 速度之間保持平衡。由于鍵值對較少,內存帶寬和數據加載需求被最小化。G 的選擇代表了一種權衡:更多的組(更接近 MHA)可帶來更高的質量但性能較慢,而更少的組(接近 MQA)可提高速度但有犧牲質量的風險。此外,隨著模型規模的擴大,GQA 允許內存帶寬和模型容量按比例減少,與模型規模相對應。相比之下,對于更大的模型,在 MQA 中減少到單個鍵和值頭可能會過于嚴重。

代碼實現

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(GroupedQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.wq = nn.Linear(embed_dim, embed_dim)

        # 這是MHA的
        # self.wk = nn.Linear(embed_dim, embed_dim)
        # self.wv = nn.Linear(embed_dim, embed_dim)

        # 這是MQA的
        # self.wk = nn.Linear(embed_dim, self.head_dim)
        # self.wv = nn.Linear(embed_dim, self.head_dim)

        # 這是GQA的
        self.group_num = 4  # 這是4個組
        self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)
        self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)

        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self, hidden, group_num=None):
        batch_size, seq_len = hidden.size()[:2]
        # q需要拆分多頭
        if group_num == None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 這是kv需要拆分的多頭
            x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,
                                           self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
            return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        # 這是MHA的
        # q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 這是MQA的
        # q, k ,v  = self.split(q), self.split(k, 1), self.split(v, 1)
        # 這是GQA的
        q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)

        # 注意力計算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 合并多頭
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

        # 線性變換
        output = self.wo(output)

        return output


x = torch.ones(3, 12, 512)
atten = GroupedQueryAttention(512, 8)
y = atten(x)
print(y.shape)

參考文獻

  • GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,https://arxiv.org/pdf/2305.13245
  • Attention Is All You Need,https://arxiv.org/pdf/1706.03762
  • Fast Transformer Decoding: One Write-Head is All You Need,https://arxiv.org/pdf/1911.02150v1


本文轉載自公眾號大模型自然語言處理  作者:余俊暉

原文鏈接:??https://mp.weixin.qq.com/s/72fGm-qYV5DdCGz-bNjuXQ??


?著作權歸作者所有,如需轉載,請注明出處,否則將追究法律責任
已于2024-11-28 18:52:23修改
收藏
回復
舉報
回復
相關推薦
主站蜘蛛池模板: 久久亚洲精品国产精品紫薇 | 成人av一区二区亚洲精 | 手机在线一区二区三区 | 色桃网| 久久视频精品在线 | 在线免费观看视频你懂的 | 久久综合一区 | 免费的色网站 | 国产成人一区二区 | 婷婷综合在线 | 久草综合在线 | 亚洲精品日韩综合观看成人91 | 精品一区二区三区在线观看 | 99亚洲精品 | 黄色大片免费播放 | 久久久久国产视频 | 精品国产91 | 中文字幕国产高清 | 国产精品久久久久久久久 | 欧洲免费毛片 | 国产精品18hdxxxⅹ在线 | 欧美一级二级在线观看 | 日韩在线欧美 | 91精品国产色综合久久不卡蜜臀 | 国内精品久久久久 | 一区不卡在线观看 | 久久久精品天堂 | 亚洲最大看片网站 | 免费在线成人 | 欧美一级在线观看 | 国产一级黄色网 | 在线看一区二区三区 | 极品在线 | 毛片视频网站 | 羞羞视频网站免费看 | 91热在线 | 刘亦菲国产毛片bd | 国产精品一区二区久久 | 久久久在线视频 | 人人干人人爽 | www.啪啪.com|