為什么大語言模型難以處理長上下文?從 Transformer 到 Mamba
OpenAI 在兩年前推出 ChatGPT 時,其能夠處理的上下文信息僅有 8,192 個 tokens1。換言之,如果輸入的文本超過大約 15 頁,它就會“遺忘”最初的上下文內容。這一限制使得 ChatGPT 在處理任務時的規模和復雜度都受到了影響。
而現今的 LLMs 能力有了顯著提升:
- OpenAI 的 GPT-4o[1] 現在能夠處理多達 128,000 個 tokens 的上下文。
- Anthropic 的 Claude 3.5 Sonnet[2] 可以處理 200,000 個 tokens 的上下文。
- Google 的 Gemini 1.5 Pro[3] 更是擁有 2 百萬個 tokens 的上下文處理能力。
盡管如此,要想讓 AI 系統達到人類水平的認知能力,我們還需要取得更多的進步。
許多人展望未來,認為 AI 將能夠承擔大部分甚至全部的人類工作。然而,人類在工作生涯會閱讀和聽到數以億計的文字,并且還能通過視覺、聽覺和嗅覺從周圍環境中獲取更多信息。要使 AI 達到人類智能水平,它們也需要具備處理如此大量信息的能力。
目前,處理大量信息的最流行 LLM 系統構建方法是“檢索增強生成”(RAG)。這類系統會尋找與用戶查詢相關的文檔,并將最相關的部分嵌入到 LLM 的上下文中。
盡管 RAG 系統在某些情況下能夠有超越傳統搜索引擎的表現,但目前這類系統仍存在諸多不足。它們只有在成功將最關鍵的文檔嵌入 LLM 的上下文時,才能產出滿意的結果。然而,用于檢索這些文檔的技術,通常是在向量數據庫[4]中進行搜索 —— 并不夠精細。如果用戶提出的問題復雜或含糊不清,RAG 系統很可能會錯誤地檢索文檔,導致聊天機器人給出錯誤的回答。
此外,RAG 系統并未讓 LLM 在處理大量文檔時展現出更高級的推理能力:
- 例如,律師可能需要 AI 幫助審閱和總結數十萬封電子郵件。
- 工程師可能需要 AI 分析數千小時的工廠監控視頻。
- 醫學研究者可能需要 AI 在數以萬計的患者病歷中識別趨勢。
這些任務任何一個都可能需要超過 200 萬個 tokens 的上下文處理能力。而且,我們希望 AI 系統在完成這些任務后,不是一切從頭開始,而是能夠像人類工作者一樣,通過經驗積累不斷提升。計算機的超強記憶力和耐力一直是其重要優勢,在 AI 時代,我們并不想放棄這些特性。但目前 LLMs 在吸收和解讀大量信息的能力上,還遠未能達到人類水平。
確實,LLMs 在訓練過程中吸收的信息量遠遠超過了人類。最新的人工智能模型已經在數萬億個 tokens 上進行了訓練,這遠遠超過了一個人一生中所能閱讀或聽到的信息量。然而,許多有價值的資料是保密的、具有時效性的,或者因為其他原因無法用于訓練。
因此,我們希望 AI 模型在推理階段能夠閱讀并記住遠超 200 萬個 tokens 的信息。但這并非易事。
基于 transformer 的 LLMs 的核心創新在于“注意力”機制,這是一種數學運算,使得模型能夠“回顧”之前的 tokens。在 LLM 生成新 token 之前,它會執行一次注意力操作,將當前 token 與之前的所有 tokens 進行比較。這導致傳統的 LLMs 在上下文增長時效率逐漸降低。
目前,許多人正在研究解決這一問題的方法,我將在本文后續部分討論其中的一些方案。但在此之前,我需要解釋一下我們是如何從一開始就形成了這樣一個復雜的架構。
1.GPUs 讓深度學習成為現實
個人電腦的核心 —— 中央處理單元(CPUs) ,曾是通過提高時鐘頻率來提升性能的。但進入 21 世紀初期,由于過熱問題,芯片制造商大多放棄了這種提速方法。
芯片制造商轉而開始研發能夠同時處理多個指令的 CPU[5]。然而,它們的進步受到了需要指令按順序執行的傳統編程模式的限制。
為了充分發揮摩爾定律[6]的潛力,一種全新的架構應運而生,那就是 Nvidia 推出的 GPUs。
1999 年,Nvidia 開始銷售 GPU,旨在加快 3D 游戲如《Quake III Arena》的渲染速度。這些作為 PC 擴展卡的 GPU,任務是迅速繪制構成游戲中墻壁、武器、怪物等物體的成千上萬的三角形。
這種任務不需要順序編程:屏幕上不同區域的三角形可以任意順序繪制。因此,Nvidia 的首款 GPU[7] 并不是采用單個處理器逐個執行指令,而是擁有十幾個專用核心 —— 類似于微型的CPU —— 它們并行作業,共同繪制場景。
隨著摩爾定律的發展,Nvidia 制造的 GPU 計算核心數量從數十個增加到數百個,最終甚至達到數千個。人們逐漸意識到,GPU 強大的并行計算能力不僅可以用于視頻游戲,還能應用于其他領域。
2012 年,多倫多大學的計算機科學家 Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 利用兩塊 Nvidia GTX 580 GPUs[8] 訓練了一個用于圖像識別的神經網絡。這兩塊 GPU 各自擁有 512 個核心,提供了巨大的計算能力,使他們能夠訓練出一個擁有 6000 萬個參數的神經網絡。他們在 ImageNet 圖像分類競賽[9]中取得了新的準確率紀錄[10](該競賽的目標是將圖像歸類到 1000 個不同類別之一)。
不久之后,研究人員開始將這些技術應用于更多領域,自然語言處理便是其中之一。
2.Transformers 打破了自然語言理解的瓶頸
在 2010 年代初,循環神經網絡(RNNs)是處理自然語言的主流架構。RNNs 采用逐詞處理的方式。神經網絡在處理完每個單詞后,會更新其隱藏狀態(hidden state),這是一組數字,代表了神經網絡對句子當前理解的程度。
RNNs 在處理短句時表現尚可,但面對長句時就顯得力不從心,更別提段落或更長的文本了。在分析長句時,RNN 有時會“遺忘”句子開頭的關鍵詞。2014 年,計算機科學家 Dzmitry Bahdanau、KyungHyun Cho 和 Yoshua Bengio 發現[11],通過引入一個注意力機制,允許網絡“回顧”句子中的早期單詞,可以提升循環神經網絡的性能。
2017 年,谷歌發布了《Attention Is All You Need》[12]這篇論文,它被譽為機器學習史上最重要的論文之一。在 Bahdanau 及其團隊的研究基礎上,谷歌的研究人員摒棄了 RNN 及其隱藏狀態的概念。他們采用的模型利用注意力機制來掃描先前的單詞,以獲取相關的上下文信息。
這種被谷歌命名為 transformer 的新架構,其重要性不言而喻,因為它消除了擴展語言模型的一個關鍵障礙。
以下動畫展示了 RNNs 為何難以擴展:
圖片
在這個假想的 RNN2 中,神經網絡試圖預測句子中的下一個單詞,預測結果展示在圖表的頂部。這個神經網絡由三層組成,每層用一個矩形表示。它的處理方式是線性的:必須先完成對第一個單詞“How”的分析,然后將隱藏狀態傳遞回底層,網絡才能開始分析第二個單詞“are”。
這種限制在機器學習算法在 CPU 上運行時還不是大問題。但當人們開始利用 GPU 的并行計算能力時,RNN 的線性架構就成為了瓶頸。
transformer 通過讓神經網絡能夠同時“思考”輸入中的所有單詞,從而突破了這一限制:
圖片
如圖所示,基于 transformer 的模型進行的計算量與前面圖中的 RNN 模型相當。因此,在(單核)CPU 上,它的運行速度可能不會更快。但由于模型不需要在處理“are”、“you”或“doing”之前完成對“How”的分析,它可以同時處理這些單詞。這意味著在擁有多個并行執行單元的 GPU 上,它的運行速度可以大幅提升。
速度提升有多大?速度的潛在提升與輸入單詞的數量成正比。以我的動畫為例,transformer 模型處理四詞輸入的速度大約是 RNN 的四倍。而對于 LLMs,其輸入可能包含數千個單詞。因此,在強大的 GPU 支持下,基于 transformer 的模型速度可以比類似的 RNN 快出幾個數量級。
你可能會問,為何不能同時用多個文檔來訓練RNN—— 即在文檔層面而非單個單詞層面利用 GPU 的并行處理能力。
這是因為訓練過程的第二階段——反向傳播的限制。在這個過程中,訓練軟件會“逆向”工作,通過微積分調整模型的參數,以提高得出正確答案的概率。對于 RNN 來說,反向傳播需要從輸入的最后一個單詞反向追溯到第一個單詞。如下圖中紅色箭頭所示:
圖片
反向傳播需要保存前向傳遞中每一步的中間結果——也就是說,訓練軟件需要存儲圖表中每個矩形的輸出。對于大模型,這些數據占用的空間極大,以至于無法同時并行訓練大量實例。3
簡而言之,transformer 釋放了 GPU 的全部處理能力,推動了語言模型規模的飛速增長。領先 LLMs 的參數量從 2018 年的數億[13]增長到了 2020 年的數千億[14]。由于傳統的基于 RNN 的模型受到線性架構的限制,它們無法在 GPU 上高效訓練,因此無法達到如此龐大的規模。
3.Transformers 模型存在擴展問題
我曾提到,在本文動畫中,循環神經網絡與 transformer 模型“大致完成了相等的工作量”。然而,兩者的工作量并非完全一致。我們再來看看 transformer 模型的工作圖:
圖片
注意到各層間那些交錯的對角線箭頭了嗎?它們代表了注意力機制的運轉。基于 transformer 的語言模型在創造新 token 前,會“審視”之前每一個已有的標記,以確定哪些最為相關。
在較小規模的上下文中,這些比較的成本微不足道。例如,對于僅有 10 個、100 個甚至 1000 個 tokens 的上下文,這些成本并不構成負擔。但隨著上下文長度的增加,注意力機制的計算成本也隨之攀升。上下文越長,為了生成下一個 token,所需的注意力操作(以及相應的計算資源)就越多。
這導致了一個問題:注意力機制總的計算能力需求與 tokens 總數成二次方關系增長。舉例來說,如果一個 10 個 tokens 的提示詞需要 414,720 次注意力操作4,那么:
- 處理一個 100 個 tokens 的提示詞,將需要 4560 萬次注意力操作。
- 處理一個 1000 個 tokens 的提示詞,將需要 46 億次注意力操作。
- 處理一個 10000 個 tokens 的提示詞,將需要 4600 億次注意力操作。
這或許也解釋了為何當上下文超過 128,000 個 tokens 時,谷歌會對 Gemini 1.5 Pro 的收費翻倍。因為生成第 128,001 個 token 時,需要與前面 128,000 個 tokens 進行比較,其成本遠高于生成第一個、第十個或第一百個 token。
4.提升注意力的效率和可擴展性
研究者們投入了大量精力優化注意力機制。其中一條研究路徑旨在最大化單個 GPU 的運算效率。
我們在前文了解到,現代 GPU 包含了成千上萬的執行單元。但在 GPU 開始進行數學運算之前,它需要將數據從較慢的共享內存(即高帶寬內存)轉移到特定執行單元內更快的內存(即SRAM)。有時,GPU 在移動數據上耗費的時間甚至超過了執行計算的時間。
在一系列論文中[15][16][17],普林斯頓大學的計算機科學家 Tri Dao 及其合作者開發了 FlashAttention,這種計算注意力的方式能夠最大限度地減少慢速內存操作的需求。Dao 等人的工作顯著提升了現代 GPU 上 transformers 的表現。
另一條研究路徑則著眼于如何在多個 GPU 上高效擴展注意力。其中一篇被廣泛引用的論文介紹了環形注意力機制(ring attention)[18],它通過將 input tokens 分成塊,并將每個塊分配給不同的 GPU 來工作。之所以稱為環形注意力,是因為 GPU 被構想為一個環形結構,每個 GPU 將其數據傳遞給相鄰的 GPU。
這讓我想起了曾參加過的一堂交誼舞課,舞伴們圍成一圈,女性保持不動,而男性則輪換舞伴。最終,每個男性都能與每位女性共舞。環形注意力的原理與之類似。"女性"代表查詢(query)向量(描述每個 token 所“尋找”的內容),"男性"代表鍵(key)向量(描述每個 token 的特征)。鍵向量在一連串 GPU 中傳遞,依次與所有查詢向量相乘。
總的來說,環形注意力機制通過在多個 GPU 間分配計算任務,使得大語言模型(LLM)能夠處理更大的上下文窗口。然而,它并未降低單個注意力計算的成本。
5.RNN 能否卷土重來?
由于 RNN 擁有固定大小的隱藏狀態(hidden state),因此它不會存在與 transformer 相同的擴展難題。無論是生成第一個、第一百個還是第一百萬個 token,RNN 所需的計算資源都相差無幾。這一點,相較于基于注意力機制的模型,RNN 具有顯著優勢。
盡管在 transformer 問世后,RNN 的地位有所下滑,但研究者們并未放棄,他們繼續探索適合在現代 GPU 上訓練的 RNN 新版本。
今年 4 月,谷歌推出了一款名為 Infini-attention[19] 的新模型。這個模型可謂是 transformer 與 RNN 的“混血兒”。Infini-attention 像傳統 transformer 那樣處理最近的 tokens,利用注意力機制記住它們并召回它們。
不過,Infini-attention 并未試圖記住所有上下文中的 tokens。相反,它采用一種“壓縮記憶(compressive memory)”來存儲較舊的 tokens,這種方式與 RNN 的隱藏狀態有幾分相似。這種數據結構能夠完美地存儲和召回少量 tokens,但隨著 tokens 數量的增加,召回率也會越來越低。
然而,機器學習領域的 YouTube 紅人 Yannic Kilcher 對谷歌的這種做法并不感冒[20]。
“我非常愿意相信這個方法確實有效,也認同這是實現無限注意力的一種途徑,但我還是持懷疑態度,”Kilcher表示。“它使用的是一種邊走邊存的壓縮記憶方法,并沒有真正學會如何存儲,只是按照一種確定性的方式在存儲,這意味著我們對存儲的內容和方式幾乎沒有控制權。”
6.Mamba 會是未來嗎?
在復興循環神經網絡(RNN)的眾多嘗試中,Mamba 架構無疑是最引人注目的。它是 2023 年 12 月發表的一篇論文[21]中公布的一種架構,其開發者是計算機科學家 Tri Dao(他也是我之前提到的 FlashAttention 的負責人)和 Albert Gu。
與傳統的 RNN 一樣,Mamba 并不依賴于注意力機制。它擁有一個充當“記憶”角色的隱藏狀態。由于這個隱藏狀態的大小是固定的,因此即使輸入的提示詞更長,也不會增加 Mamba 處理每個 token 的成本。
我在三月著手撰寫這篇文章時,本打算深入剖析 Mamba 的架構。然而,到了五月,研究團隊推出了 Mamba-2[22],其架構較之初代 Mamba 有了顯著的改變。坦白說,我一直在努力理解初代 Mamba 的原理,而對于 Mamba-2 的工作機制更是尚未完全弄清。
然而,我們需要明白的是,Mamba 有潛力將 transformer 模型的性能和傳統 RNN 的效率結合起來。
在六月,Dao 和 Gu 與 Nvidia 的研究人員合作發表了一篇論文[23],對擁有 80 億參數的 Mamba 模型進行了評估。研究發現,Mamba 模型在多項任務中與同等規模的模型不相上下,但在“上下文學習”和“從上下文中提取信息”的能力上,Mamba 模型仍略遜一籌。
transformer 模型之所以擅長信息提取,是因為它們能夠“記住”上下文中的每一個 token —— 這也是為什么隨著上下文長度的增加,transformer 模型的效率會降低。而 Mamba 則試圖將整個上下文壓縮到一個固定大小的狀態中,這意味著在處理長上下文時,它不得不舍棄一部分信息。
Nvidia 團隊發現,通過采用一種混合架構,該架構將 24 個 Mamba 層與 4 個注意力層交錯排列,他們獲得了最佳性能。這種混合架構的表現優于單純的 transformer 模型或單純的 Mamba 模型。
模型需要一些注意力層來記住其早期上下文中的關鍵細節。但是,似乎只需要少量的注意力層就夠了;其余的注意力層可以由成本更低的 Mamba 層替換,而對模型的整體性能影響很小。
在八月,一家名為 AI21 的以色列初創公司發布了其 Jamba 1.5 系列模型[24]。其中最大版本的參數數量達到了 3980 億,使其在規模上與 Meta 的 Llama 405B 模型相當。Jamba 1.5 Large 模型的 Mamba 層數量是注意力層的七倍。因此,Jamba 1.5 Large 所需的內存遠少于 Meta 和其他公司的同類模型。例如,AI21 估計 Llama 3.1 70B 需要 80 GB 的內存來跟蹤 256,000 個上下文 token ,而 Jamba 1.5 Large 只需要 9 GB,這使得模型能夠在性能較弱的硬件上運行。
Jamba 1.5 Large 模型的 MMLU 得分為 80,顯著低于 Llama 3.1 70B 的 86 分。因此,按照這個標準,Mamba 并沒有完全超越 transformer 模型。然而,這可能并不是一個完全公平的比較。像 Meta 這樣的前沿實驗室在訓練數據和后訓練基礎設施上投入了大量資金,以在 MMLU 等基準測試中提高幾個百分點的性能。同樣的高強度優化可能會縮小 Jamba 與前沿模型之間的差距。
因此,雖然更長上下文窗口的好處顯而易見,但達到這一目標的最優策略尚不明確。短期內, AI 公司可能會繼續使用巧妙的效率和擴展技巧(如 FlashAttention 和 Ring Attention)來擴展標準的 LLMs。長期來看,我們可能會看到對 Mamba 以及其他無注意力架構的興趣日益增長。或者也許有人會提出一種全新的架構,使 transformers 過時。
但我確信,僅僅依靠擴大基于 transformers 的前沿模型規模并不是一個完整的解決方案。如果我們想要能夠處理數十億個 tokens 的模型——許多人都有這樣的需求,我們就需要跳出固有的思維模式,尋找新的方法。
1 有網絡消息指出,ChatGPT 起初設定的上下文窗口為 4,096 個 tokens,但發布后不久的一次實驗[25]顯示,它能夠記憶超過這個數量的信息。
2 在十年前,循環神經網絡(RNN)通常會包含編碼器和解碼器兩部分,而像 GPT-3 這樣的現代大語言模型(LLM)則只有解碼器。出于教學目的,我展示了一個與歷史不符的、僅包含解碼器的 RNN 模型,這樣可以更容易地與 GPT-3 等現代 LLM 進行對比。同樣的分析方法也適用于 2010 年代初的真實 RNN 模型,但那時的模型圖會更加復雜。
3 GPU 能夠將中間計算結果傳輸到它的大容量高帶寬內存中。但是,由于高帶寬內存(HBM)[26]的速度限制,這一操作并不會提升訓練速度。
4 這是我針對擁有 1750 億參數版本的 GPT-3 的一個初步估算,該模型包含 96 層,每層有 96 個注意力頭。因此,實際上每對 tokens 之間需要進行 9,216 次注意力計算。
5 Jamba 模型是一種混合專家模型,這意味著對于任何一個 token,只有網絡中的一部分(980 億參數中的 3980 億)會被激活和使用。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the author
Timothy B. Lee
I write the newsletter Understanding AI and cohost the AI Summer podcast. Previously I was a reporter at Ars Technica, Vox, and the Washington Post. twitter.com/binarybits