想把半本《紅樓夢》搬進ChatGPT輸入框?先把這個問題解決掉
過去兩年,斯坦福大學 Hazy Research 實驗室一直在從事一項重要的工作:增加序列長度。
他們有一種觀點:更長的序列將開啟機器學習基礎模型的新時代 —— 模型可以從更長的上下文、多種媒體源、復雜的演示等中學習。
目前,這項研究已經取得了新進展。Hazy Research 實驗室的 Tri Dao 和 Dan Fu 主導了 FlashAttention 算法的研究和推廣,他們證明了 32k 的序列長度是可能的,且在當前這個基礎模型時代將得到廣泛應用(OpenAI、Microsoft、NVIDIA 和其他公司的模型都在使用 FlashAttention 算法)。
- 論文地址:https://arxiv.org/abs/2205.14135
- 代碼地址:https://github.com/HazyResearch/flash-attention
正如 GPT4 的相關資料所指出的,它允許近 50 頁的文本作為上下文,而且像 Deepmind Gato 使用圖像作為上下文那樣實現 tokenization/patching。
在這篇文章中,作者介紹了關于在高層級上增加序列長度的新方法,并提供了連接一組新原語的「橋梁」。
Transformer 變得越來越深,越來越寬,但在長序列上訓練它們仍然很困難。研究人員遇到的一個基本問題是,Transformer 的注意力層在序列長度方面是按二次方比例增長:就是說從 32k 長度增加到 64k 長度,成本不只增加 2 倍,而是增加了 4 倍。因此,這促使研究人員探索具有線性時間復雜度的序列長度模型。在 Hazy Research 實驗室,這項工作從 Hippo 開始,然后是 S4、H3,再到現在的 Hyena。這些模型有可能處理數百萬、甚至十億級別的上下文長度。
FlashAttention 可以加速注意力并減少其內存占用 —— 無需任何近似。「自從我們在 6 個月前發布 FlashAttention 以來,我們很高興看到許多組織和研究實驗室采用 FlashAttention 來加速他們的訓練和推理。」博客中寫道。
FlashAttention 是一種對注意力計算進行重新排序并利用經典技術(平鋪、重新計算)加快速度并將內存使用從序列長度的二次減少到線性的算法。對于每個注意力頭,為了減少內存讀 / 寫,FlashAttention 使用經典的平鋪技術將查詢、鍵和值塊從 GPU HBM(其主內存)加載到 SRAM(其快速緩存),計算關于該塊的注意力,并將輸出寫回 HBM。在大多數情況下,這種內存讀 / 寫的減少帶來了顯著的加速(2-4 倍)。
接下來,讓我們看一下研究細節。
Long Range Arena 基準和 S4
谷歌的研究人員在 2020 年推出了 Long Range Arena (LRA) 基準測試,以評估不同模型處理長程依賴的能力。LRA 能夠測試一系列任務,涵蓋多種不同的數據類型和模式,例如文本、圖像和數學表達式,序列長度可達 16K(Path-X:對已展開成像素的圖像進行分類,沒有任何空間歸納偏置)。關于將 Transformer 擴展到更長的序列方面已經有很多出色的工作,但其中許多似乎會犧牲準確性(如下圖所示)。請注意 Path-X 那一列:所有 Transformer 方法及其變體表現甚至不如隨機猜測。
現在讓我們認識一下由 Albert Gu 主導研發的 S4。受到 LRA 基準測試結果的啟發,Albert Gu 想要找出如何更好地對長程依賴關系建模,在正交多項式和遞歸模型與卷積模型之間關系的長期研究基礎上,推出了 S4—— 一種基于結構化狀態空間模型(SSMs)的新的序列模型。
很關鍵的一點是,SSM 在將長度為 N 的序列拓展到 2N 時的時間復雜度為,而不像注意力機制一樣呈平方級別增長!S4 成功地對 LRA 中的長程依賴進行了建模,并成為首個在 Path-X 上獲得高于平均性能的模型(現在可以獲得 96.4%的準確度!)。自 S4 發布以來,許多研究人員在此基礎上發展和創新,出現了像 Scott Linderman 團隊的 S5 模型、Ankit Gupta 的 DSS(以及 Hazy Research 實驗室后續的 S4D)、Hasani 和 Lechner 的 Liquid-S4 等新模型。
另外,當 Hazy Research 發布 FlashAttention 時,已經能夠增加 Transformer 的序列長度。他們還發現,僅通過將序列長度增加到 16K,Transformer 也能在 Path-X 上獲得不凡的表現(63%)。
建模方面的不足
但是 S4 在語言建模方面的質量存在的差距高達 5% 的困惑度(對于上下文,這是 125M 模型和 6.7B 模型之間的差距)。為了縮小這一差距,研究人員研究了諸如聯想回憶之類的合成語言,以確定語言應該具備哪些屬性。最終設計了 H3(Hungry Hungry Hippos):一個堆疊兩個 SSM 的新層,并將它們的輸出與乘法門相乘。
使用 H3,Hazy Research 的研究人員替換了 GPT 式 Transformer 中的幾乎所有注意力層,并能夠在從 Pile 訓練的 400B 規模的 token 時,在困惑度和下游評估方面與 transformer 相媲美。
由于 H3 層建立在 SSM 上,因此在序列長度上,它的計算復雜度也以的速度增長。兩個注意力層使得整個模型的復雜度仍然是
,稍后會詳細討論這個問題。
當然,Hazy Research 不是唯一考慮這個方向的人:GSS 也發現帶有門控的 SSM 可以與語言建模中的注意力很好地協同工作(這啟發了 H3),Meta 發布了 Mega 模型,它也將 SSM 和注意力結合起來,BiGS 模型則替換了 BERT-style 模型中的注意力,而 RWKV 一直在研究完全循環的方法。
新進展:Hyena
根據前面的一系列工作,啟發 Hazy Research 的研究人員開發了新的架構:Hyena。他們試圖擺脫 H3 中最后兩個注意力層,并獲得一個幾乎呈線性增長的模型,以適應更長的序列長度。事實證明,兩個簡單的想法是找到答案的關鍵:
- 每個 SSM 都可以看作是一個長度與輸入序列相同的卷積濾波器。因此,可以用一個大小等于輸入序列的卷積來替換 SSM,以獲得在相同計算量下更加強大的模型。具體來說,通過另一個小型神經網絡來隱式地參數化卷積濾波器,這借鑒了關于神經場文獻中的強大方法和 CKConv/FlexConv 的研究成果。此外,卷積可以在 O (NlogN) 的時間內計算,其中 N 是序列長度,實現了近乎線性的擴展;
- H3 中的門控行為可以概括為:H3 采用輸入的三個投影,并迭代地進行卷積和應用門控。在 Hyena 中,只需添加更多投影和更多的門,這有助于泛化到更具表現力的架構并縮小與注意力的差距。
Hyena 首次提出了完全近線性時間卷積模型,它可以在困惑度和下游任務上與 Transformer 相匹配,并在實驗中取得了很好的結果。并且在 PILE 的子集上訓練了中小型模型,其表現與 Transformer 相媲美:
通過一些優化(更多內容見下文),在序列長度為 2K 時,Hyena 模型的速度略慢于相同大小的 Transformer,但在更長的序列長度上會更快。
接下來仍需思考的是,究竟能將這些模型推廣到什么程度?是否能將它們擴展到 PILE 的全尺寸(400B 個 token)?如果結合 H3 和 Hyena 的思想精華,會發生什么,能走多遠?
FFT 還是更基本的方法?
在所有這些模型中,一個常見的基本操作是 FFT,它是高效計算卷積的方式,只需要 O (NlogN) 的時間。然而,FFT 在現代硬件上的支持很差,因為現代硬件主流架構是專用的矩陣乘法單元和 GEMMs(例如 NVIDIA GPU 上的張量核心)。
可以通過將 FFT 重寫為一系列矩陣乘法操作來縮小效率差距。研究小組的成員利用蝴蝶矩陣來探索稀疏訓練,從而實現這個目標。最近,Hazy Research 研究人員利用這個連接構建了快速卷積算法,例如 FlashConv 和 FlashButterfly,通過使用蝴蝶分解將 FFT 計算轉化為一系列矩陣乘法操作。
此外,通過借鑒之前的工作,還能建立更深入的聯系:包括讓這些矩陣被學習,這同樣需要相同的時間,但會增加額外的參數。研究人員已經開始在一些小型數據集上探索這種聯系,并取得了初步成效。我們可以清楚地看到這種聯系可以帶來什么(比如,如何使其適用于語言模型):
這一擴展值得更深入的探索:這個擴展學習的是哪類轉換,它能讓你做什么?當將它應用于語言建模時會發生什么?
這些方向都是令人興奮的,接下來會是越來越長的序列和新的架構,讓我們能夠進一步探索這個新領域。我們需要特別關注那些能夠受益于長序列模型的應用,比如高分辨率成像、新的數據形式,能夠閱讀整本書的語言模型等等。想象一下,把整本書給語言模型閱讀,并讓它總結故事情節,或者讓一個代碼生成模型基于你寫的代碼來生成新的代碼。這些可能的場景非常非常多,都是讓人感到非常興奮的事情。