如何擴展大模型的上下文長度
一、背景
大模型的上下文長度是指我們在使用大模型的時候,給大模型的輸入加上輸出的字符(Token)總數,這個數字會被限制,如果超過這個長度的字符會被大模型丟棄。目前開源的大模型上下文長度一般不長,比如 Llama 2 只有 4K,Code-Llama 系列因為需要輸入代碼,擴展到了 16K。閉源系列模型的提供了更長的上下文長度,比如 OpenAI 在其最新模型 GPT-4 Turbo 中提供了 128K 的上下文長度,Anthropic 的 Claude 2.1 模型提供了 200K 上下文長度。
一些場景需要較長上下文,比如,文檔翻譯需要將整篇文檔輸入給大模型進行翻譯,長文檔內容抽取需要大模型讀取整篇長文檔進行內容抽取,會議內容總結則需要給大模型輸入會議聊天記錄進行總結等。
想要得到一個長上下文的大模型,一般有兩種途徑。一種是大模型在初始階段被設置為長上下文,然后經過預訓練,指令微調,對齊訓練等方式得到一個長上下文大模型。另外一種方式是選擇已經訓練好的大模型,通過技術改造擴展其上下文長度,然后再進行微調訓練得到長上下文模型。
圖片
本文將基于比較火的 Llama 2 大模型的結構[1]介紹上下文長度的方法與挑戰,然后探討一些業界流行的上下文長度擴展的技術,最后給大家推薦下 KubeAI 大模型訓練推理平臺可以上手實驗。
二、LLAMA的結構
Transformer的結構
通常所說的大模型是指大語言模型(Large Language Model,LLM),其模型結構一般基于 Transformer 進行改進而來。Transformer 源自于 2017 年 Google 發表的著名論文"Attention Is All You Need[2]"。論文中的 Transformer 結構如下,下圖源自論文[2]。
圖片
它的結構包括兩部分:Encoder(編碼器)和Decoder(解碼器)。Encoder 與 Decoder 大致都包含以下層,每一層都有特定的功能,下面為 Encoder(編碼器)各層的簡單介紹:
輸入嵌入層(Input Embedding Layer):將輸入文本的詞或標記轉換為向量表示,以便模型能夠理解它們。
多頭自注意力層(Multi-Head Self-Attention Layer):幫助模型捕捉輸入序列中詞與詞之間的關系,使模型能夠了解上下文信息。
前饋神經網絡層(Feed-Forward Neural Network Layer):對多頭自注意力的輸出進行進一步的特征提取和變換,以增加模型的表示能力。
歸一化層(Layer Normalization Layer):規范化每一層的輸出,有助于訓練過程的穩定性。
總的來說,Transformer 是一種強大的模型,它可以捕捉文本和序列數據中的長距離依賴關系,使其在翻譯、對話、摘要生成等自然語言處理任務中表現出色。這個模型已經在各種應用中取得了顯著的成功。感興趣的同學可以自行去網上搜索下 Transformer 的結構,深入了解。
LLAMA的結構
目前大多數生成式語言模型如 Llama 系列,僅僅采用了 Transformer 的 Decoder 模塊結構。在 Huggingface 中,這種結構通常被稱為 CausalLM,即因果語言模型(Causal Language Model)。下面我們來具體看一下 Llama 2 模型系列的結構,Llama 2 相關的論文[1]。
(下圖是基于 Transfomers 代碼中的 LlamaModel 繪制而成,具體代碼參考 Transfomer 中的 modeling_llama.py[3])
圖片
我們來解讀下上面 Llama 各層的結構與作用,首先從輸入文本開始。會經過下面各層:
- Input Embedding:將 Input 文本轉化為向量表,通過 nn.Embedding 實現。
- Llama Decoder Layer:Decoder 采用多層 Llama Decoder Layer。每一層包括自注意力(Llama Attention)和前饋網絡(Llama MLP)。自注意力用于捕捉文本中的長程依賴關系。前饋網絡進行非線性映射。
- Llama RMSNorm:一種規范化方式,用于正則化每層的輸,起到預處理的作用。
- lm_head:一個線性層,將 Decoder 最后一層的輸出映射到詞典大小的維,以進行后續的語言模型 Logits 計算。
- Llama Attention:多頭自注意力機制,用于建模文本中的依賴關系。將輸入表示切分為多個頭,然后在每個頭內做點積注意力運算。
- Llama MLP:采用 Gated Linear Units 的多層前饋網絡。進行非線性變換來捕捉復雜模式。
總體上,Llama 通過堆疊多層自注意力和前饋網絡來表示文本語義,然后預測后續詞元。lm_head 負責將語義表示映射為具體的詞典 Logits。整個模型端到端通過語言模型目標進行訓練。
上面各層中比較核心的是 Llama Attention 層,該層的結構如下:
圖片
論文“Attention Is All You Need”[2]中描述的 Attention 的計算公式如下:
圖片
Llama 的 Attention 計算過程如下:
- 輸入會經過線性變換,得到 Query(Q)、Key(K)和 Value(V)矩陣。
- 對 Q 和 K 應用 RoPE 位置編碼。RoPE 包含旋轉的 Sin 和 Cos 編碼,會根據每個 Token 的位置對其表示進行旋轉。
- 用旋轉后的 Q 和 K 計算點積,得到注意力權重 Attention Score,經過 Softmax 計算后得到 Normalized Attention Weight。
- 再把 Attention Weight 與 V 相乘,并進行加權求和,得到 Attention 的輸出。
- 輸出再經過一個線性變換,繼續輸出給下一層作為輸入。
這樣通過 RoPE 位置編碼、加權平均,Llama 的 Attention 可以高效穩定地提取文本序列的上下文語義信息。
三、擴展方案與挑戰
位置編碼層(RoPE)
通過上面對 Llama 結構的解析,我們看到在 Llama Attention 中有一個叫做 RoPE(旋轉位置編碼)的層,主要用于對輸入進行位置編碼,讓模型學到輸入文本中每個 Token 的位置關系,從而更好地理解輸入。RoPE 層能處理序列的長度決定了 Llama 的上下文長度,要擴展 Llama 的上下文長度,需要對 RoPE 層進行改造和擴展。下面我們先簡單介紹下 Llama RoPE 層的工作原理。
RoPE 旋轉位置編碼,最早來自 RoFormer:Enhanced Transformer with Rotary Position Embedding[4]這篇論文。下圖源自論文[4]。
圖片
RoPE 層是一種相對位置編碼方法,它給輸入的每個 Token 編碼一個向量,向量中的每個值表示該 Token 與其他 Token 的相對距離。論文以二維向量為例,解釋了這種位置編碼為什么叫做旋轉位置編碼。如上圖所示,在二維平面,相當于把向量旋轉了一個 Q 的角度。
論文中證明,在進行旋轉位置編碼之后,可以從新的編碼向量中獲取原向量的相對位置信息,即論文中下面的公式中的 m-(位置 m 減去位置 n)。下圖源自論文[4]。
圖片
而我們只需要理解旋轉位置編碼的最終計算公式如下,對于一個輸入向量 X,直接與一個 COS 矩陣和 SIN 矩陣進行內積后求和:下圖源自論文[4]。
圖片
其中 WCOS 和 WSIN 分別是一個預先固定的 COS 矩陣和 SIN 矩陣。
下面我們展示下 Huggingface 的 Transformer 對應的計算 COS 和 SIN 矩陣的計算代碼:
圖片
擴展位置編碼層(RoPE)支持長上下文
上面介紹了 Llama 結構中的旋轉位置編碼層 RoPE。要擴展大模型的上下文長度,就需要擴展 RoPE 層,也就是擴展其 COS 和 SIN 矩陣,讓RoPE支持更長序列的輸入。
RoPE 中的 COS 和 SIN 矩陣維度(seq_length, embed_dim),其中 seq_length 就是模型支持的最大序列長度,embed_dim 是詞嵌入維度。矩陣中的每個值表示一個位置上的正弦或余弦編碼。為了支持更長的上下文,需要重新計算更大尺寸的 COS 和 SIN 矩陣。
對于未訓練過的大模型,只需要直接更改其配置文件中的 max_position_embeddings 即可實現 RoPE 層的擴展,然后再進行訓練。但是對于已經訓練過的模型,如果直接修改其配置,會導致模型的效果急劇下降,后面第四部分我們會介紹一些基于已有模型進行改造擴展的上下文的方法。max_position_embeddings 的配置如下:
圖片
超長上下文面臨的挑戰
超長上下文的大模型部署推理的時候,往往會面臨如下性能挑戰。
- 推理時間變長
從上面的 Attention 的計算公式可以看出,Attention 進行了點積的運算,其時間復雜度為 L(序列長度)的平方。也就是說大模型在推理的時候,輸入的序列長度越長推理時間越多。所以超長上下文的大模型需要更多的推理時間,這會帶來用戶體驗上的損失。
- 推理顯存空間變大
大模型在持續推理的過程中,需要緩存一個叫做 KV Cache 的數據快,KV Cache 的大小也與序列長度成正比。以 Llama 2 13B 大模型為例,一個 4K 長的序列大約需要 3G 的顯存去緩存 KV Cache,16K 的序列則需要 12G,128K 的序列則需要 100G 顯存。
圖片
超長上下文的大模型需要更多的 KV Cache 存儲空間,但是 GPU 顯存非常珍貴,比如 A100 也只有 40G 或 80G 顯存兩個版本,這對本來就比較緊張的 GPU 顯存來說是一個很大的挑戰。
大模型上下文擴展的思路
綜上所述,擴展大模型的上下文長度,一般思路如下:
圖片
- 首先通過對位置編碼層進行改造,使其支持更長的上下文。
- 為了取得更好的推理性能,還需要對 Attention 計算進行優化。
- 進行微調訓練,讓大模型適應新的模型結構。
四、位置編碼層改造擴展上下文的案例
上面我們講到,從模型物理結構上擴展上下文長度,需要直接修改 RoPE 層,即直接擴展其 SIN 和 COS 矩陣。但是大模型都是基于大量短序列數據訓練得到的。如果直接強行擴展,會導致模型困惑度提高。所謂困惑度是模型對下一個詞的預測困難程度的量化指標,直觀意義是大模型的輸出是否能夠更容易被人類所理解。
因此,我們需要更好的方法來擴展預訓練模型的上下文長度,既要兼顧模型性能,又要控制困惑度。下面我們概括幾種業界常用的上下文長度擴展方法。
圖片
線性位置插值法
線性插值法的思想最早來自于這篇文章,Extending Context Window of Large Language Models via Positional Interpolation[5],該方法已經被 Huggingface 的 Transformer 中 Llama 模型代碼集成。
下圖源自論文:
圖片
思路:通過線性縮小輸入位置索引以匹配原始上下文窗口大小,而不是超出訓練上下文長度進行外推,這樣可以減小注意力機制中相對位置的影響,幫助模型更容易適應擴展后的上下文窗口。
效果:在從 LLaMA 7B 到 65B 模型上,通過位置插值擴展上下文窗口到 32768(4k擴展到32K),僅需微調 1000 步,就能在包括語言建模、長文檔摘要撰寫等任務上取得良好效果。
優點:位置插值不僅能有效擴展上下文窗口,提高模型在長上下文任務上的性能,還能在原有上下文窗口大小的任務上保持模型質量,且不需要額外的權重或修改模型架構。
缺點:需要重新訓練,有時候擴充后會導致模型困惑度上升。
動態插值法(NTK-awared)
動態插值法是在位置插值法的基礎上演變而來的,最早提出文章 NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning[6]。現在也被 Huggingface 的 Transformer 中 Llama 模型代碼集成。
下面是 Chinese Llama 基于位置插值法與動態插值法進行的比較,數據來自 Extend Context Size Without Fine-Tuning[7]。
圖片
數據顯示,相比與位置插值法,NTK 動態插值法不會顯著增加大模型的困惑度。
思路:利用神經正切核 (NTK) 理論,設計非線性位置編碼插值方案,改變基數而不是縮放比例,使不同位置可區分,避免線性插值的問題。
效果:與線性插值相比,在大幅擴展上下文(如8000+)時,無需微調就可以使困惑度下降極小。
優點:微調成本極低,上下文窗口可以擴展很大,困惑度變化小。
總體來說,動態插值法通過考慮模型特性,設計更優化的插值方案,能夠在不增加訓練成本的條件下,獲得接近無損的上下文窗口擴展效果。這為進一步擴展和優化大語言模型提供了新的思路。
Yarn(NTK升級)
Yarn 擴展上下文的方法來自于文章 YaRN: Efficient Context Window Extension of Large Language Models[8]。代碼參考Yarn[9]。并給出基于 Llama2 的 128K 上下文擴展。
下圖源自論文[8]:
圖片
數據顯示,在 128K 的 Proof-Pile 數據集上評測,Yarn-Llama-2-7b-128K/64K 模型的困惑度仍然保持良好下降。
相比于僅進行簡單線性插值或動態插值的方法,Yarn 方法更全面地考慮了不同頻率 RoPE 維度的作用,避免了信息損失和外推問題。這使得 Yarn 方法在不 Fine-Tuning 的情況下,以及 Fine-Tuning 的數據量很少的情況下,都能更好地擴展上下文窗口。
RoPE 中的每個維度對應著不同的正弦波頻率。高頻的正弦波編碼了位置信息的細微變化,低頻的正弦波編碼了位置信息的整體趨勢。
如果我們簡單地進行線性插值,會把所有頻率的正弦波都等比例地拉伸。這會導致兩個問題:
- 高頻正弦波被過度拉伸,導致代表細微位置變化的信息丟失。這個會影響模型區分很接近的詞的能力。
- 低頻正弦波被拉伸,不同位置之間的相對距離變小。這會導致模型判斷近距離詞的先后順序變得困難。
為了解決這個問題,Yarn 方法對不同頻率的正弦波進行不同程度的插值:
- 對高頻正弦波幾乎不進行插值,保留細微位置信息。
- 對低頻正弦波進行接近線性的插值,保留位置大體信息。
- 中頻正弦波進行漸變的插值。
這樣既保留了高頻表示細微位置變化的信息,也保留了低頻表示位置整體關系的信息,避免了簡單線性插值的問題。
五、優化Attention擴展上下文的案例
上面我們提到,長上下文對大模型的正向與反向傳播的性能來說是個挑戰。其主要原因是 Attention(注意力)的計算復雜度比較高,為了解決這個問題,業界提出了很多優化 Attention 計算的方法。
LongLoRA方法
LongLoRA 是香港中文大學聯合 MIT 提出的一種模型微調方法,其論文 LONGLORA: EFFICIENT FINE-TUNING OF LONG-CONTEXT LARGE LANGUAGE MODELS[10]。下圖是論文[10]中描述的方法:
圖片
LongLoRA 提出了一種移位稀疏注意力(Shifted Sparse Attention,S2-Attn)來近似標準的自注意力。
在傳統的自注意力(self-attention)中,模型需要計算輸入序列中所有元素對之間的注意力權重,這在處理長序列時會導致計算復雜度呈二次方增長。S2-Attn 在訓練時,將輸入序列劃分成若干個組,在每個組內部進行自注意力計算。為了使不同組之間有信息流通,在一半的注意力頭內,向其中一個組的 Tokens 做平(Shift)操作,平移的長度為組的一半。這樣就引入了不同組之間的信息交換,又不增加計算量。
S2-Attn 的設計使得大型語言模型能夠在處理長序列時保持較高的性能,同時顯著降低了訓練和推理時的計算資源需求。
實驗結果如下,用了一個 8 個 A100 的機器微調,將 Llama2 7B/13B/70B 模型分別擴展到 100K,64K, 32K 長度,而大模型的困惑度并沒有明顯變化。
圖片
六、業界更多擴展上下文的方法
最近 Technology Innovation Institute(TII)發表了一篇論文綜述The What, Why, and How of Context Length Extension Techniques in Large Language Models – A Detailed Survey[11],調研了業界的擴展大模型的技術,下圖源自論文。
圖片
論文中介紹了更多的業界上下文擴展的方法,大致可以簡單分為一下幾大主要方式。
圖片
七、Kubeai大模型訓練推理平臺
上面我們分別講解了 Llama 的結構,然后基于 Llama 的結構去講解了業界最新擴展大模型上下文長度的方法與效果。
圖片
在 KubeAI 訓練推理平臺上,用戶只需要上傳數據、選擇大模型,就可以完成一次訓練和推理部署。如果想了解詳細使用方法,可以參考我們之前發表的關于得物大模型平臺的系列介紹的文章。
KubeAI 平臺為用戶提供了非常便捷的大模型訓練和部署功能。用戶無需關注底層基礎設施,就可以通過簡單的步驟上傳數據、配置參數、選擇模型,從而獲得針對自己業務自定義的大模型。
八、總結與展望
本文從 Llama 大模型的結構入手,介紹了其模塊結構,重點解析了 Attention 機制中的 RoPE 層。要實現 Llama 模型上下文長度的擴展,需要對應擴展 RoPE 位置編碼層。但是直接擴展會導致模型困惑度上升,針對這個問題,我們介紹了業界常見的幾種上下文擴展方法,包括位置查找法、動態插值法和 Yarn 方法等。
長下文推理對性能要求比較高,為此我們也介紹了一些為了提升性能而優化 Attention 的方法,比如 LongLoRA[10] 這篇論文的 S2-Atten 的方法。有興趣的同學可以閱讀相關論文了解細節。
本文通過剖析 Llama 模型結構,解析上下文擴展的關鍵層 RoPE,并概述各種擴展方法的原理,希望能夠幫助大家對大模型上下文擴展有一個系統的了解。后續如果有機會,我們會繼續分享更多大模型的核心技術,讓更多人對大模型的內在機制有更深的認識。歡迎持續關注我們的內容和分享!
參考資料
- [1] Llama 2: Open Foundation and Fine-Tuned Chat Models(https://arxiv.org/abs/2307.09288)
- [2] Attention Is All You Need(https://arxiv.org/pdf/1706.03762.pdf)
- [3]modeling_llama.py(https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)
- [4] RoFormer: Enhanced Transformer with Rotary Position Embedding(https://arxiv.org/abs/2104.09864)
- [5] Extending Context Window of Large Language Models via Positional Interpolation(https://arxiv.org/abs/2306.15595)
- [6] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning(https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/)
- [7] Extend context size without fine-tuning(https://github.com/ymcui/Chinese-LLaMA-Alpaca/pull/705)
- [8] YaRN: Efficient Context Window Extension of Large Language Models(https://arxiv.org/abs/2309.00071)
- [9] Yarn(https://github.com/jquesnelle/yarn)
- [10] LONGLORA: EFFICIENT FINE-TUNING OF LONG-CONTEXT LARGE LANGUAGE MODELS(https://arxiv.org/pdf/2309.12307.pdf)
- [11] The What, Why, and How of Context Length Extension Techniques in Large Language Models – A DetailedSurvey(https://arxiv.org/pdf/2401.07872.pdf)