鄂維南院士領(lǐng)銜新作:大模型不止有RAG、參數(shù)存儲,還有第三種記憶
近年來,大型語言模型 (LLM) 因其非凡的性能而獲得了前所未有的關(guān)注。然而, LLM 的訓(xùn)練和推理成本高昂,人們一直在嘗試通過各種優(yōu)化方法來降低成本。
本文來自上海算法創(chuàng)新研究院、北京大學(xué)等機(jī)構(gòu)的研究者受人類大腦記憶層次結(jié)構(gòu)的啟發(fā),他們通過為 LLM 配備顯式記憶(一種比模型參數(shù)和 RAG 更便宜的記憶格式)來降低這一成本。從概念上講,由于其大部分知識都外化為顯式記憶,因而 LLM 可以享受更少的參數(shù)大小、訓(xùn)練成本和推理成本。
- 論文地址:https://arxiv.org/pdf/2407.01178
- 論文標(biāo)題:Memory3 : Language Modeling with Explicit Memory
作為初步的概念證明,研究者從零開始訓(xùn)練了一個 2.4B 的 LLM,它比更大的 LLM 和 RAG 模型獲得了更好的性能,并實(shí)現(xiàn)了比 RAG 更高的解碼速度。這個模型被命名為 Memory3,因為在 LLM 中,顯式記憶是繼隱式記憶(模型參數(shù))和工作記憶(上下文鍵值)之后的第三種記憶形式。
具體而言,本文引入了一種新的記憶格式,即顯式記憶,其特點(diǎn)是寫入成本和讀取成本相對較低。如圖 1 所示,模型首先將知識庫(或任何文本數(shù)據(jù)集)轉(zhuǎn)換為顯式記憶,實(shí)現(xiàn)為稀疏注意力鍵 - 值,然后在推理過程中調(diào)用這些內(nèi)存并將其集成到自注意力層中。
新的記憶格式定義了新的記憶層次結(jié)構(gòu):
此外,本文還介紹了一種支持知識外化的記憶電路理論,并提出了可以讓存儲易于處理的記憶稀疏機(jī)制和促進(jìn)記憶形成的兩階段預(yù)訓(xùn)練方案。
總結(jié)而言:
- Memory3 在推理過程中利用顯式記憶,減輕了模型參數(shù)記憶特定知識的負(fù)擔(dān);
- 顯式記憶是從構(gòu)建的知識庫中編碼而來的,其中稀疏記憶格式保持了真實(shí)的存儲大小;
- 研究者從頭開始訓(xùn)練了一個具有 2.4B 非嵌入?yún)?shù)的 Memory3 模型,其性能超過了更大規(guī)模的 SOTA 模型。它還比 RAG 具有更好的性能和更快的推理速度;
- 此外,Memory3 提高了事實(shí)性并減輕了幻覺,并能夠快速適應(yīng)專業(yè)任務(wù)。
方法介紹
記憶電路理論有助于確定哪些知識可以存儲為顯式記憶,以及哪種模型架構(gòu)適合讀取和寫入顯式記憶。
研究者將輸入輸出關(guān)系作為電路的內(nèi)部機(jī)制,并將知識定義為輸入輸出關(guān)系及其電路。通過操縱這些電路,人們可以從 LLM 中分離出許多知識,同時保持其功能完好無損。
Memory3:在架構(gòu)方面,本文的目標(biāo)是為 Transformer LLM 設(shè)計一個顯式的記憶機(jī)制,使其寫入成本和讀取成本都比較低。此外,本文希望將對 Transformer 架構(gòu)的修改限制在盡可能小的范圍內(nèi),不添加任何新的可訓(xùn)練參數(shù),這樣大多數(shù)現(xiàn)有的 Transformer LLM 都可以在幾乎不進(jìn)行微調(diào)的情況下轉(zhuǎn)換為 Memory3 模型。簡單的設(shè)計過程如下:
寫入成本:在推理之前,LLM 將每個參考寫入顯式記憶,保存在驅(qū)動器上。記憶是從自注意力層的鍵值向量中選擇的,因此寫入過程不涉及訓(xùn)練。每個引用都是獨(dú)立處理的,避免了長上下文注意力的成本。
讀取成本:在推理過程中,顯式記憶從驅(qū)動器中檢索,并與通常的上下文鍵值一起由自注意力讀取。每個記憶由來自少量注意力頭的極少量鍵值組成,從而大大減少了額外的計算、GPU 存儲、驅(qū)動器存儲和加載時間。它允許 LLM 頻繁檢索許多參考,而對解碼速度的影響有限。
推理過程如圖 9 所示,每當(dāng) LLM 生成 64 個 token 時,它就會丟棄當(dāng)前記憶,使用這 64 個 token 作為查詢文本來檢索 5 個新記憶,并繼續(xù)使用這些記憶進(jìn)行解碼。同樣,在處理提示時,LLM 會為每 64 個 token 塊檢索 5 個記憶。每個塊都會關(guān)注自己的記憶,并且不同塊之間的記憶可能會有所不同。
寫入與讀取記憶:在推理過程中,LLM 可以通過其自注意力層直接讀取檢索到的顯式記憶,方法是將它們與上下文鍵值連接起來(圖 9)。具體來說,對于第 l 層的每個注意力頭 h,如果它被選為記憶頭,那么它的輸出 Y^( l,h ) 將會改變:
此外,該研究對所有顯式記憶采用并行位置編碼,即所有鍵位置都位于長度為 128 的同一區(qū)間內(nèi),如圖 9 所示。
兩階段預(yù)訓(xùn)練:預(yù)訓(xùn)練由兩個階段組成,warmup 和持續(xù)訓(xùn)練。只有持續(xù)訓(xùn)練階段涉及顯式記憶,而 warmup 階段使用與普通預(yù)訓(xùn)練相同的格式。
圖 13 繪制了 warmup 階段訓(xùn)練損失和學(xué)習(xí)率時間表。
圖 14 繪制了持續(xù)訓(xùn)練階段訓(xùn)練損失和學(xué)習(xí)率時間表。
實(shí)驗結(jié)果
研究者評估了 Memory3 模型的一般能力(基準(zhǔn)任務(wù))、對話能力、專業(yè)能力(法律和醫(yī)學(xué))以及幻覺。此外,研究者還測量了 Memory3 的解碼速度,并與類似和更大的 SOTA LLM 以及 RAG 模型進(jìn)行了比較。
一般能力的評估結(jié)果如下所示,結(jié)果表明顯式記憶使平均分提高了 2.51%。相比之下,Llama2-7B 與 13B 的得分差距為 4.91%。顯式記憶可以將「有效模型大小」提高 2.51/4.91 ≈ 51.1%。
接下來作者評估了 Memory3 的對話技巧,結(jié)果列于表 18 中,表明模型以更少的參數(shù)勝過 Vicuna-7B、Falcon-40B-Instruct 和 ChatGLM2-6B。
目前,LLM 仍然面臨幻覺問題。從概念上講,Memory3 應(yīng)該不太容易受到幻覺的影響,因為它的顯式記憶直接對應(yīng)于參考文本。為了評估幻覺,研究者選擇了兩個英文數(shù)據(jù)集進(jìn)行評估。結(jié)果如表 19 所示,Memory3 在大多數(shù)任務(wù)上都取得了最高分。
使用顯式記憶的一個好處是,LLM 可以通過更新其知識庫輕松適應(yīng)新領(lǐng)域和任務(wù)。只需將與任務(wù)相關(guān)的參考導(dǎo)入 Memory3 的知識庫,并可選擇在熱啟動的情況下將其轉(zhuǎn)換為顯式記憶。然后,該模型可以利用這些新知識進(jìn)行推理,跳過成本更高且可能有損的微調(diào)過程,并且運(yùn)行速度比 RAG 更快。圖 4 已證明這種成本降低,并且可以促進(jìn) LLM 在各個行業(yè)的快速部署。
下表表明,Memory3 的表現(xiàn)優(yōu)于大多數(shù)模型。
最后,研究者通過每秒生成的 token 數(shù)來評估 Memory3 的解碼速度或吞吐量。
了解更多內(nèi)容,請參考原論文。