用短輸入模擬長(zhǎng)樣本,高效拓展LLM上下文窗口,北大聯(lián)合MSRA提出PoSE
論文題目:PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training
論文鏈接:https://arxiv.org/abs/2309.10400
代碼鏈接:https://github.com/dwzhu-pku/PoSE
一、研究簡(jiǎn)介
大型語言模型(LLMs)通常有一個(gè)預(yù)定義的上下文窗口大小,這限制了它們?cè)陂L(zhǎng)輸入的場(chǎng)景中的使用。為了使 LLMs 適應(yīng)更長(zhǎng)的輸入,通常需要用目標(biāo)長(zhǎng)度的樣本對(duì)其進(jìn)行微調(diào)(全長(zhǎng)微調(diào)),由此導(dǎo)致訓(xùn)練成本十分昂貴。
舉例來說,在 Positional Interpolation [1] 這份工作中,將 LLaMA 的上下文窗口從 2048 拓展到 8192 使用了 32 張 A100,對(duì)于更大的上下文窗口則使用了 128 張 A100。
為了將訓(xùn)練長(zhǎng)度與目標(biāo)長(zhǎng)度解耦合,以實(shí)現(xiàn)高效的上下文窗口擴(kuò)展,我們提出了一種稱為位置跳躍式訓(xùn)練(Positional Skip-wisE training, PoSE)的方法,在原始的上下文窗口中模擬更長(zhǎng)的訓(xùn)練樣本。
如下圖所示,我們將原始的上下文窗口分成幾塊,然后引入不同的 bias 項(xiàng)來調(diào)整每個(gè)塊的位置編碼。對(duì)于每一條訓(xùn)練樣本,這些 bias 項(xiàng)和塊的長(zhǎng)度都會(huì)發(fā)生變化,因此通過大量的訓(xùn)練,模型能適應(yīng)目標(biāo)長(zhǎng)度內(nèi)的所有位置。
實(shí)驗(yàn)結(jié)果表明,PoSE 有以下三方面的優(yōu)勢(shì):
- 訓(xùn)練的時(shí)空效率:由于只需要按照原始的上下文長(zhǎng)度進(jìn)行訓(xùn)練,PoSE避免了由目標(biāo)上下文長(zhǎng)度增加帶來的平方級(jí)別的計(jì)算復(fù)雜度,使得訓(xùn)練對(duì)于內(nèi)存和時(shí)間的開銷都大大減小。
- 能支持極長(zhǎng)的上下文:通過解耦合訓(xùn)練長(zhǎng)度和目標(biāo)長(zhǎng)度,我們僅使用2k的訓(xùn)練窗口就成功將 LLaMA 拓展到 128k。
- 兼容所有基于 RoPE 的模型和位置插值策略:PoSE 的有效性在 LLaMA、GPT-J、Baichuan 等多種基礎(chǔ)模型,和 Linear、NTK、YaRN 等多種插值策略上得到了驗(yàn)證。
二、技術(shù)背景
旋轉(zhuǎn)位置編碼 RoPE:RoPE 是當(dāng)下主流的位置編碼方式,被 LLaMA、GPT-J 等大語言模型所采用。給定一個(gè) 維的隱向量
和位置
,RoPE 通過如下方式編碼位置信息:
其中 。此前的絕對(duì)位置編碼多是直接作用在輸入向量
上,與之不同的是,RoPE 是在作用在每一層的 query 和 key 向量上。RoPE 可以看作是一種相對(duì)位置編碼,給定位置
處的 query 向量
和位置
處的 key 向量
,注意力分?jǐn)?shù)
可以寫成如下函數(shù):
上下文窗口擴(kuò)展:給定一個(gè)以為原始上下文窗口長(zhǎng)度的大語言模型,我們的目標(biāo)是其支持的上下文長(zhǎng)度拓展到
,使得在
個(gè)輸入內(nèi)能較好地保持原有的性能。
位置插值(PI):為了將 LLM 的上下文窗口從?
拓展到 ,一種直接的做法是使用
長(zhǎng)的輸入文本
,設(shè)定其位置編碼為
, 對(duì) LLM 進(jìn)行微調(diào)。
然而,實(shí)踐表明 [1] [2],這部分的位置在前向傳播時(shí)會(huì)產(chǎn)生災(zāi)難性的離群值,從而導(dǎo)致訓(xùn)練無法達(dá)到預(yù)期的效果。這主要是因?yàn)槟P驮陬A(yù)訓(xùn)練時(shí)只見過
這些位置,無法很好的泛化到外推出去的這部分位置。
為了解決這個(gè)問題,Position Interpolation [1] 這份工作首先提出用“內(nèi)插”代替“外推”,設(shè)定縮放因子,并將上述注意力公式修改為
(也就是將位置編碼線性修改為
)。
這種方式可以減少離群值的出現(xiàn),將上下文窗口拓展到了 32k。在此基礎(chǔ)上,NTK 提出通過修改 來進(jìn)行位置插值,取得了更好的效果。YaRN 則根據(jù)不同的維度,對(duì)上述線性插值和 NTK 進(jìn)行了整合。
三、方法描述
盡管上述 Linear / NTK / YaRN 等插值方式能一定程度上解決位置外推的問題,他們?nèi)匀恍枰媚繕?biāo)長(zhǎng)度的訓(xùn)練樣本來訓(xùn)練模型(即全長(zhǎng)微調(diào))。
隨著目標(biāo)長(zhǎng)度的增加,平方級(jí)別的計(jì)算復(fù)雜度帶來的開銷依舊是難以承受的。因此,在插值技術(shù)的基礎(chǔ)上,我們提出調(diào)整原始的上下文窗口中的位置編碼,來模擬更長(zhǎng)的訓(xùn)練樣本,從而實(shí)現(xiàn)高效的上下文擴(kuò)展。
位置編碼的調(diào)整主要有兩個(gè)考量:
- 為了避免推理時(shí)遇到 out-of-distribution 的相對(duì)位置,調(diào)整后的位置編碼應(yīng)覆蓋 所有這些相對(duì)位置;
- 用調(diào)整后的位置編碼來微調(diào) LLM 不應(yīng)該損害其原有性能,因此調(diào)整后的位置編碼的結(jié)構(gòu)應(yīng)該和預(yù)訓(xùn)練時(shí)盡可能接近。
第一步:我們將原上下文窗口 分成 N 個(gè)塊
,每個(gè)塊的長(zhǎng)度為
,滿足
。記
的起始位置編碼為
,則這個(gè)塊的位置編碼如下:
第二步:我們從離散均勻分布 中采樣出跳躍偏置項(xiàng)
,并施加到
的位置編碼上:
為了避免塊之間位置編碼的重合,我們施加了 這一限制。值得注意的是,對(duì)于每條數(shù)據(jù),我們會(huì)重新采樣每個(gè)塊的大小和跳躍偏置項(xiàng)。直觀上來說,通過這種方式,我們擴(kuò)大了原上下文窗口能覆蓋的相對(duì)位置范圍,并且位置編碼的不連續(xù)只發(fā)生在塊之間,因此盡可能地保持了預(yù)訓(xùn)練階段的位置編碼結(jié)構(gòu)。
第三步:選定每個(gè)塊內(nèi)的內(nèi)容。給定輸入文本,我們用類似的方法來抽取每個(gè)塊內(nèi)的填充的內(nèi)容:
我們也嘗試了其它 的賦值方式,如
,此時(shí)塊間的內(nèi)容也是連續(xù)的;或如
,此時(shí)調(diào)整后的位置編碼恰好對(duì)應(yīng)訓(xùn)練數(shù)據(jù)在原始文本中的位置。實(shí)驗(yàn)結(jié)果表明,這幾種賦值方式并沒有明顯的差別。
第四步:位置插值及超參初始化。我們使用位置插值來使訓(xùn)練更穩(wěn)定。 和
設(shè)置成 0,N 設(shè)置為 2 。
四、實(shí)驗(yàn)分析
1. 實(shí)驗(yàn)設(shè)置
訓(xùn)練過程:我們主要使用 LLaMA-7B 作為基模型,對(duì)于所有設(shè)定都只訓(xùn)練 1000 步,訓(xùn)練時(shí)長(zhǎng)度為 2k,batch size 為 64。我們使用 8 張 V100 進(jìn)行訓(xùn)練,1 張 A100 進(jìn)行推理。對(duì)于我們的方法和各個(gè) baseline,我們都默認(rèn)采用線性插值來使訓(xùn)練更穩(wěn)定。
Baseline:
- 全長(zhǎng)微調(diào)(Full-length fine-tuning)
- 隨機(jī)位置(RandPos):給定目標(biāo)長(zhǎng)度 和原始長(zhǎng)度 ,從 中隨機(jī)采樣 個(gè)位置,按升序排列,作為位置編碼。
2. 主要結(jié)果
語言模型:
我們使用滑動(dòng)窗口的方式來計(jì)算困惑度 PPL。在 GovReport 和 Proof-Pile 兩個(gè)數(shù)據(jù)集上,PoSE 的性能和 Full-length 十分接近,遠(yuǎn)超未做窗口擴(kuò)展的版本(Original)和隨機(jī)位置的版本(RandPos)。且隨著窗口長(zhǎng)度從 2k 增加到 32k,PPL 呈下降趨勢(shì),說明拓展后的模型能充分利用更長(zhǎng)的上下文信息。
密碼檢索:
在密碼檢索任務(wù)上,利用 PoSE 拓展到 16k 和 32k 的模型能分別在 16k 和 32k 的上下文內(nèi)取得接近 100% 的密碼檢索準(zhǔn)確率,說明模型能關(guān)注到目標(biāo)長(zhǎng)度內(nèi)的每個(gè)位置。
時(shí)空效率:
在時(shí)空效率方面,全長(zhǎng)微調(diào)的訓(xùn)練時(shí)長(zhǎng)和內(nèi)存消耗隨目標(biāo)長(zhǎng)度的增加而迅速增長(zhǎng),相比之下,PoSE 需要的訓(xùn)練時(shí)間和內(nèi)存較為穩(wěn)定。并且在每個(gè)時(shí)間步上,性能和全長(zhǎng)微調(diào)都很接近。
兼容性:
兼容性方面,PoSE 可以適配 LLaMA、LLaMA2、GPT-J、Baichuan2 等各種基于 RoPE 的基礎(chǔ)模型,以及 Linear、NTK、YaRN 等各種插值策略,展現(xiàn)出較好的普適性。其中 NTK 在最后階段會(huì)有一個(gè) PPL 的突增,這主要是因?yàn)榻o定縮放因子 ,NTK 實(shí)際實(shí)現(xiàn)的縮放倍數(shù)會(huì)略小于 [3]。YaRN 解決了這個(gè)缺陷,取得了三者中最好的效果。
超長(zhǎng)上下文拓展的潛力:
只使用 2k 的訓(xùn)練長(zhǎng)度和 1000 步的訓(xùn)練步數(shù),我們嘗試了將 LLaMA 模型拓展到 128k。實(shí)驗(yàn)表明,在使用 YaRN 的情況下,模型在 128k 的窗口下仍然能保持較低的 PPL。
原窗口內(nèi)的語言能力:
最后,我們分析了經(jīng)由 PoSE 訓(xùn)練過后的模型在原窗口內(nèi)的語言能力。可以看出,和全長(zhǎng)微調(diào)以及原始模型相比,PoSE 模型能力的損失非常微小,這說明 PoSE 在拓展上下文窗口的同時(shí)較好地保持了模型的基礎(chǔ)能力。
五、總結(jié)與討論
本文提出了一種位置跳躍式訓(xùn)練(PoSE)來高效的拓展大語言模型的上下文窗口。通過調(diào)整位置編碼,PoSE 在原始的上下文窗口中模擬更長(zhǎng)的訓(xùn)練樣本,以達(dá)到解耦合訓(xùn)練長(zhǎng)度和目標(biāo)長(zhǎng)度的目的。
實(shí)驗(yàn)結(jié)果表明 PoSE 在和全長(zhǎng)微調(diào)保持同等性能的情況下,大大縮小了訓(xùn)練所需的時(shí)空開銷,并表現(xiàn)出良好的普適性和超長(zhǎng)上下文擴(kuò)展的潛力。我們相信 PoSE 將大大降低上下文窗口拓展的成本,使更多人可以參與到相關(guān)的研究中來,從而推動(dòng)長(zhǎng)上下文建模領(lǐng)域的快速發(fā)展。
PoSE 完成于 2023 年 9 月,我們相信這種位置跳躍的思路是 Long Context 的有效解決方案。結(jié)合近幾個(gè)月來 Long Context 相關(guān)研究的進(jìn)展,我們認(rèn)為 PoSE 可能有以下一些方面值得進(jìn)一步探究:
- 應(yīng)用范圍的拓展:融合到預(yù)訓(xùn)練或者 SFT 階段中。PoSE 的實(shí)驗(yàn)主要是對(duì)基礎(chǔ)模型進(jìn)行輕量級(jí)的 post-pretrain,如果能直接融入預(yù)訓(xùn)練中,可能可以更好的解決推理時(shí)長(zhǎng)文本位置 out-of-distribution 的問題;如果能用于 SFT 中,則可以將模型更好地適配到具體的下游任務(wù)或者 alignment 要求上。
- 效果的提升:探索更優(yōu)的 skip 策略和合適的訓(xùn)練數(shù)據(jù)配比。一方面,實(shí)驗(yàn)中只將窗口分成了兩塊,且跳步由隨機(jī)采樣決定,如何在更多場(chǎng)景設(shè)計(jì)更科學(xué)合理的 skip 結(jié)構(gòu)指的關(guān)注。
另一方面,PoSE 簡(jiǎn)單從 The Pile 數(shù)據(jù)集采樣了部分超過 2k 長(zhǎng)的樣本作為訓(xùn)練數(shù)據(jù),并沒有特別關(guān)注數(shù)據(jù)的來源、長(zhǎng)度等配比。根據(jù) Fu et al. (2024) [4] 的結(jié)論,優(yōu)化訓(xùn)練數(shù)據(jù)的分布對(duì)于長(zhǎng)上下文建模能力的獲取以及模型原始能力的保持都能有較大的幫助。
本文轉(zhuǎn)載自PaperWeekly
