比Transformer更好,無Attention、MLPs的BERT、GPT反而更強了
從 BERT、GPT 和 Flan-T5 等語言模型到 SAM 和 Stable Diffusion 等圖像模型,Transformer 正以銳不可當之勢席卷這個世界,但人們也不禁會問:Transformer 是唯一選擇嗎?
斯坦福大學和紐約州立大學布法羅分校的一個研究團隊不僅為這一問題給出了否定答案,而且還提出了一種新的替代技術:Monarch Mixer。近日,該團隊在 arXiv 公布了相關論文和一些檢查點模型及訓練代碼。順帶一提,該論文已入選 NeurIPS 2023 并獲得 Oral Presentation 資格。
論文地址:https://arxiv.org/abs/2310.12109
代碼地址:https://github.com/HazyResearch/m2
該方法去掉了 Transformer 中高成本的注意力和 MLP,代之以富有表現力的 Monarch 矩陣,使之在語言和圖像實驗中以更低的成本取得了更優的表現。
這并不是斯坦福大學第一次提出 Transformer 的替代技術。今年六月該校的另一個團隊還曾提出過一種名為 Backpack 的技術,參閱機器之心文章《斯坦福訓練 Transformer 替代模型:1.7 億參數,能除偏、可控可解釋性強》。當然,這些技術要取得真正的成功,還需要研究社區的進一步檢驗并在應用開發者手中變成切實好用的產品。
下面我們看看這篇論文中對 Monarch Mixer 的介紹以及一些實驗結果。
論文介紹
在自然語言處理和計算機視覺領域,機器學習模型已能處理更長的序列和更高維度的表征,從而支持更長的上下文和更高的質量。然而,現有架構的時間和空間復雜性在序列長度和 / 或模型維度上呈二次增長模式,這會限制上下文長度并提升擴展成本。舉個例子,Transformer 中的注意力和 MLP 會隨序列長度和模型維度呈二次擴展模式。
針對這一問題,斯坦福大學和紐約州立大學布法羅分校的這個研究團隊聲稱找到了一種高性能的架構,其復雜度隨序列長度和模型維度的增長是次二次的(sub-quadratic)。
他們的研究靈感來自 MLP-mixer 和 ConvMixer;這兩項研究觀察到:許多機器學習模型的運作方式都是沿序列和模型維度軸對信息進行混合,并且它們往往對兩個軸使用了單個算子。
尋找表現力強、次二次且硬件效率高的混合算子的難度很大。舉個例子,MLP-mixer 中的 MLP 和 ConvMixer 中的卷積都頗具表現力,但它們都會隨輸入維度二次擴展。近期有一些研究提出了一些次二次的序列混合方法,這些方法使用了較長的卷積或狀態空間模型,而且它們都會用到 FFT,但這些模型的 FLOP 利用率很低并且在模型維度方面依然是二次擴展。與此同時,不損質量的稀疏密集 MLP 層方面也有一些頗具潛力的進展,但由于硬件利用率較低,某些模型實際上可能還比密集模型更慢。
基于這些靈感,這個研究團隊提出了 Monarch Mixer (M2),其使用到了一類富有表現力的次二次結構化矩陣:Monarch 矩陣。
Monarch 矩陣是一類泛化了快速傅立葉變換(FFT)的結構化矩陣,并且研究表明其涵蓋了范圍廣泛的線性變換,包括哈達瑪變換、托普利茲矩陣、AFDF 矩陣和卷積。它們可通過分塊對角矩陣的積進行參數化,這些參數被稱為 Monarch 因子,與排列交織。
它們的計算是次二次擴展的:如果將因子的數量設為 p,則當輸入長度為 N 時,計算復雜度為 ,從而讓計算復雜度可以位于 p = log N 時的 O (N log N) 與 p = 2 時的 之間。
M2 使用了 Monarch 矩陣來沿序列和模型維度軸混合信息。這種方法不僅易于實現,而且硬件效率也很高:使用支持 GEMM(廣義矩陣乘法算法)的現代硬件就能高效地計算分塊對角 Monarch 因子。
該研究團隊實現了一個 M2 層來進行概念驗證 —— 完全使用 PyTorch 編寫,代碼行數不到 40(包括 import 軟件包),而且其只需依賴矩陣乘法、轉置、reshape 和逐元素乘積(見圖 1 中部的偽代碼);結果,對于大小為 64k 的輸入,這些代碼在一臺 A100 GPU 上實現了 25.6% 的 FLOP 利用率。在 RTX 4090 等更新的架構上,對于同樣大小的輸入,一個簡單的 CUDA 實現就能實現 41.4% 的 FLOP 利用率。
有關 Monarch Mixer 的更多數學描述和理論分析請參看原論文。
實驗
該研究團隊在 Transformer 已占主導地位的三個任務上對 Monarch Mixer 和 Transformer 進行了比較:BERT 風格的非因果掩碼語言建模任務、ViT 風格的圖像分類任務、GPT 風格的因果語言建模任務。
在每個任務上,實驗結果表明新提出的方法在不使用注意力和 MLP 的前提下均能達到與 Transformer 相媲美的水平。他們還在 BERT 設置中評估了新方法相較于強大 Transformer 基準模型的加速情況。
非因果語言建模
對于非因果語言建模任務,該團隊構建了一種基于 M2 的架構:M2-BERT。M2-BERT 可以直接替代 BERT 風格的語言模型,而 BERT 是 Transformer 架構的一大主力應用。對于 M2-BERT 的訓練,使用了在 C4 上的掩碼語言建模,token 化器則是 bert-base-uncased。
M2-BERT 基于 Transformer 骨干,但其中的注意力層和 MLP 被 M2 層替換,如圖 3 所示。
在序列混合器中,注意力被帶殘差卷積的雙向門控卷積替代(見圖 3 左側)。為了恢復卷積,該團隊將 Monarch 矩陣設置為 DFT 和逆 DFT 矩陣。他們還在投射步驟之后添加了逐深度的卷積。
在維度混合器中,MLP 中兩個密集矩陣被替換成了學習得到的分塊對角矩陣(1 階 Monarch 矩陣,b = 4)。
研究者預訓練了 4 個 M2-BERT 模型:其中兩個是大小分別為 80M 和 110M 的 M2-BERT-base 模型,另外兩個是大小分別為 260M 和 341M 的 M2-BERT-large 模型。它們分別相當于 BERT-base 和 BERT-large。
表 3 給出了相當于 BERT-base 的模型的性能表現,表 4 給出了相當于 BERT-large 的模型的性能表現。
從表中可以看到,在 GLUE 基準上,M2-BERT-base 的表現可以媲美 BERT-base,同時參數還少了 27%;而當兩者參數數量相當時,M2-BERT-base 勝過 BERT-base 1.3 分。類似地,參數少 24% 的 M2-BERT-large 與 BERT-large 表現相當,而參數數量一樣時,M2-BERT-large 有 0.7 分的優勢。
表 5 給出了相當于 BERT-base 的模型的前向吞吐量情況。其中報告的是在 A100-40GB GPU 上每毫秒處理的 token 數,這能反映推理時間。
可以看到,M2-BERT-base 的吞吐量甚至超過了經過高度優化的 BERT 模型;相較于在 4k 序列長度上的標準 HuggingFace 實現,M2-BERT-base 的吞吐量可達其 9.1 倍!
表 6 則報告了 M2-BERT-base (80M) 和 BERT-base 的 CPU 推理時間 —— 結果是直接運行這兩個模型的 PyTorch 實現得到的。
當序列較短時,數據局部性的影響依然主導著 FLOP 的減少情況,而過濾器生成(BERT 中沒有)等操作的成本更高。而當序列長度超過 1K 時,M2-BERT-base 的加速優勢就漸漸起來了,當序列長度達 8K 時,速度優勢可達 6.5 倍。
圖像分類
在非因果建模方面,為了驗證新方法在圖像上也有在語言上一樣的優勢,該團隊還評估了 M2 在圖像分類任務上的表現。
表 7 給出了 Monarch Mixer、ViT-b、HyenaViT-b 和 ViT-b-Monarch(用 Monarch 矩陣替換了標準 ViT-b 中的 MLP 模塊)在 ImageNet-1k 上的性能表現。
Monarch Mixer 優勢非常明顯:只需一半的參數量,其表現就能勝過原始 ViT-b 模型。而更讓人驚訝的是,參數更少的 Monarch Mixer 很能勝過 ResNet-152;要知道,ResNet-152 可是專門針對 ImageNet 任務設計的。
因果語言建模
GPT 風格的因果語言建模是 Transformer 的一大關鍵應用。該團隊為因果語言建模構建了一個基于 M2 的架構:M2-GPT。
對于序列混合器,M2-GPT 組合使用了來自 Hyena 的卷積過濾器、當前最佳的無注意力語言模型以及來自 H3 的跨多頭參數共享。他們使用因果參數化替換了這些架構中的 FFT,并完全移除了 MLP 層。所得到的架構完全沒有注意力,也完全沒有 MLP。
他們在因果語言建模的標準數據集 PILE 上對 M2-GPT 進行了預訓練。結果見表 8。
可以看到,盡管基于新架構的模型完全沒有注意力和 MLP,但其在預訓練的困惑度指標上依然勝過 Transformer 和 Hyena。這些結果表明,與 Transformer 大不相同的模型也可能在因果語言建模取得出色表現。
了解更多內容,請參考原論文。