機器學習|從0開始大模型之位置編碼
1、什么是位置編碼
在語言中,一句話是由詞組成的,詞與詞之間是有順序的,如果順序亂了或者重排,其實整個句子的意思就變了,所以詞與詞之間是有順序的。在循環神經網絡中,序列與序列之間也是有順序的,所以循環神經網絡中,序列與序列之間也是有順序的,不需要處理這種問題。但是在Transformer中,每個詞是獨立的,所以需要將詞的位置信息添加到模型中,讓模型維護順序關系。
位置編碼
位置編碼就是將hello world! 的token和位置關系通過向量表示出來,作為訓練的輸入數據,如上圖,位置編碼最終會變成:
[
[P00, P01, P02 ... P0d],
[P10, P11, P12 ... P1d],
[P20, P21, P22 ... P2d],
]
2、計算位置編碼
計算位置編碼有多種方式:固定位置編碼,相對位置編碼,絕對位置編碼,其中Transformer的作者設計了一種三角函數位置編碼方式,通過三角函數計算輸出位置編碼向量。
為什么三角函數可以作為計算位置編碼的函數?
- 首先我們來回顧一下三角函數的基本性質:函數具有周期性,取值范圍是[-1, 1]。
sin
- 其次,如果用絕對位置編碼計算最大序列為3的位置(0-7),二進制表示如下:
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1]
]
從上可以表示看出,較高比特位的交替頻率低于較低比特位,存在周期性bit位變化,符合三角函數的周期性,而且三角函數的取值范圍是[-1, 1],輸出浮點數,并且數據連續,比直接使用二進制更節省空間。
3、Transformer中的位置編碼層
假設你有一個長度為L的輸入序列,要計算第K個元素的位置編碼,位置編碼由不同頻率的正弦和余弦函數給出:
函數
- k:詞序列中的第K個元素
- d:詞向量維度,比如512,1024,8K等
- P(k, i):位置函數,輸出位置編碼向量
- n:定義的標量,Attention Is All You Need 的作者設置為 10,000
- i:映射到列索引,范圍是0~d/2(由于輸入是2i表示,如果用i表示,范圍可以是0~d)
按照上述Hello world!的例子,計算位置編碼結果如下:
計算結果
那么用代碼實現一個簡化版本的位置編碼:
import numpy as np
def getPositionEncoding(seq_len, d, n=10000):
P = np.zeros((seq_len, d))
for k in range(seq_len):
for i in np.arange(int(d/2)):
denominator = np.power(n, 2*i/d)
P[k, 2*i] = np.sin(k/denominator)
P[k, 2*i+1] = np.cos(k/denominator)
return P
P = getPositionEncoding(seq_len=3, d=3, n=100)
print(P)
# 輸出結果:
[[ 0. 1. 0. ]
[ 0.84147098 0.54030231 0. ]
[ 0.90929743 -0.41614684 0. ]]
4、大模型訓練中的位置編碼代碼
在我們從0訓練大模型中,其位置編碼的實現如下:
def precompute_pos_cis(dim: int, seq_len: int, theta: float = 10000.0):
"""預計算相對位置編碼的復數形式,用于旋轉位置編碼(RoPE)。"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # 計算頻率
t = torch.arange(seq_len, device=freqs.device) # 創建時間步長
freqs = torch.outer(t, freqs).float() # 計算頻率的外積
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # 生成復數形式的頻率
return pos_cis # 返回預計算的復數位置編碼
def apply_rotary_emb(xq, xk, pos_cis):
"""應用旋轉位置編碼到查詢和鍵。"""
def unite_shape(pos_cis, x):
"""調整位置編碼的形狀以匹配輸入張量的形狀。"""
ndim = x.ndim # 獲取輸入的維度
assert 0 <= 1 < ndim # 確保維度有效
assert pos_cis.shape == (x.shape[1], x.shape[-1]) # 確保位置編碼形狀匹配
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 生成新形狀
return pos_cis.reshape(*shape) # 調整位置編碼的形狀
# 將查詢和鍵轉換為復數形式
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_) # 調整位置編碼形狀
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) # 應用位置編碼并轉換回實數
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) # 同上
return xq_out.type_as(xq), xk_out.type_as(xk) # 返回與輸入類型一致的輸出
這里使用的是RoPE旋轉位置編碼,和相對位置編碼相比,RoPE 具有更好的外推性,Meta 的 LLAMA 和 清華的 ChatGLM 都使用該編碼,目前是大模型相對位置編碼中應用最廣的方式之一,具體原理由于篇幅原因就不講了,可以看看這篇文章:https://cloud.tencent.com/developer/article/2327751。
參考
(1)http://www.bimant.com/blog/transformer-positional-encoding-illustration/(2)https://hub.baai.ac.cn/view/29979
本文轉載自 ??周末程序猿??,作者: 周末程序猿
