首個基于統計學的線性注意力機制ToST,高分拿下ICLR Spotlight
本文第一作者為加州大學伯克利分校三年級博士生吳梓陽,導師為馬毅教授。吳的主要研究方向為表征學習與多模態學習。該工作由多所學校與機構的研究者共同完成,包括加州大學伯克利分校、賓夕法尼亞大學、密歇根大學、清華大學、憶生科技、香港大學、約翰·霍普金斯大學等。據悉,馬毅教授已受邀在今年四月的ICLR大會上就和此項成果相關的一系列白盒神經網絡相關工作,進行為時一小時的主題報告(Keynote)。
Transformer 架構在過去幾年中通過注意力機制在多個領域(如計算機視覺、自然語言處理和長序列任務)中取得了非凡的成就。然而,其核心組件「自注意力機制」 的計算復雜度隨輸入 token 數量呈二次方增長,導致資源消耗巨大,難以擴展到更長的序列或更大的模型。
Token Statistics Transformer (ToST) 提出了一種新的注意力機制,它的時間復雜度是線性的。通過對序列特征的統計建模,ToST 提高了序列處理任務中的效率。文章探討了基于變分編碼率縮減(Variational Rate Reduction, VRR)的框架,并通過實驗驗證了其在不同任務中的性能,通過革新傳統注意力機制,解決了這些長期困擾 Transformer 架構的效率瓶頸。
ToST 也作為 Spotlight 論文,入選了 ICLR 2025 大會。
- 論文標題:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
- 論文地址:https://arxiv.org/abs/2412.17810
- 項目主頁:https://robinwu218.github.io/ToST/
- 目前該工作已開源:https://github.com/RobinWu218/ToST
研究背景與動機
一直以來,自注意力機制依賴于對輸入 token 兩兩相似性的計算,這一過程雖然有效,但其資源開銷顯著;尤其當輸入 token 數量極大時,傳統注意力機制(如 Transformer 中的全局注意力)在計算復雜度和內存使用上的瓶頸問題愈發顯著。
為了應對這一挑戰,本文提出了一種基于統計學特征的注意力機制:Token Statistics Self-Attention (TSSA)。它通過避免兩兩相似性的計算,僅依賴于 token 特征的統計量,顯著降低了計算復雜度。
Token Statistics Transformer (ToST) 的架構。Token Statistics Self-Attention (TSSA) 運算符通過對投影后的 token 進行行標量化變換,從而實現了線性復雜度。
核心方法
ToST 的核心方法是通過特定的概率分布函數對輸入序列進行建模,減少冗余信息并提取關鍵特征。具體包括:
1. 統計特征提取:對序列中的每個 token 提取其統計特征。
2. 變分編碼率縮減:利用 VRR 框架對特征進行壓縮,減少信息冗余。
3. 線性復雜度實現:通過一系列優化,其計算復雜度從 O (n2) 降低為 O (n)。
ToST 的方法概述。在 CRATE 的理論基礎上,ToST 通過幾何空間的結構化特征實現 token 分組和映射。
網絡架構的推導
該團隊通過擴展先前的 CRATE 工作推導出網絡架構。CRATE 顯示,一種 Transformer 風格的架構可以通過 "白盒" 架構設計自然生成,其中網絡的每一層都旨在實現最大編碼率縮減目標 (MCR2) 的增量優化步驟。
具體來說,該團隊推導了 MCR2 目標的一個新穎的變分形式,并表明通過對該變分目標進行展開梯度下降所得到的架構會引入一種新的注意力模塊,稱為 Token Statistics Self-Attention (TSSA)。TSSA 擁有線性的計算和內存復雜度,并從根本上不同于典型的注意力架構,其后者通過計算 token 之間的兩兩相似性來實現。
關鍵公式 MCR2 目標函數定義
技術細節
1. 線性時間注意力機制:Token Statistics Self-Attention (TSSA)
通過白盒設計方法(algorithmic unrolling),TSSA 從最大編碼率減少(Maximal Coding Rate Reduction, MCR2 )的變分形式中推導而來。
傳統 Transformer 依賴于 pairwise 相似度計算,而 TSSA 則基于 token 特征的統計量構建注意力機制,其計算復雜度從 O (n2) 降低為 O (n),內存占用同樣顯著減少。
2. 創新性的網絡結構:Token Statistics Transformer (ToST)
ToST 通過將 TSSA 替代標準的自注意力模塊,不僅實現了顯著的效率提升,還增強了模型的可解釋性。
與傳統模型不同,ToST 架構中的注意力操作基于統計量的低秩投影,通過減少不必要的計算路徑,大幅優化了資源使用。
3. 理論支撐與數學推導
基于 MCR2 的變分形式,提出了一種新穎的壓縮項公式,可對大型矩陣進行有效的特征提取。
通過設計數據相關的低秩投影,TSSA 在保留關鍵信息的同時,消除了冗余方向。
實驗驗證與性能分析
實驗覆蓋了自然言語處理(NLP)、計算機視覺(CV)等多個領域的任務,包括文本分類、機器翻譯、圖像識別等。結果表明,ToST 在保證模型性能的同時,大幅降低了計算資源消耗。
1. 計算和內存的線性復雜度分析
實驗結果顯示,與現有的注意力機制相比,TSSA 的時間和內存復雜度更低。具體而言,TSSA 的復雜度為 O (pn),顯著優于傳統 Transformer 的 O (n2)。
ToST 在計算時間和內存使用上均隨序列長度實現線性擴展,使其顯著優于標準 Transformer 的效率。如下:
復雜度分析對比
在 GPU 上評估的速度和內存使用對比
2. 視覺任務性能分析
在 ImageNet-1k 等主流視覺數據集上的實驗表明,ToST 的性能可與傳統 Transformer 架構(如 ViT 和 XCiT)相媲美,同時顯著減少了模型參數量和計算開銷。
遷移學習實驗中,ToST 在 CIFAR、Oxford Flowers 等數據集上的表現進一步驗證了其在多種視覺任務中的適應性。
結果展示了與傳統 Transformer 相當的性能,同時在計算效率上顯著更高。
3. 長序列任務和語言建模
- 長序列任務
在長序列任務基準測試(如 Long-Range Arena)中,ToST 展現出優異的長距離建模能力,其性能超越了現有 Transformer 變體。
- 語言建模
ToST 可以擴展并適用于多種任務場景,包括因果語言建模。針對語言建模,ToST 采用了一種因果版本的 TSSA,在多個數據集上實現了高效的預測能力。此外,即使在參數規模擴大的情況下,ToST 依然保持了優異的時間和內存效率。
NLP 任務中的表現
4. 有原理支持的模型設計
由于 ToST 是通過展開從學習目標中推導出來的,我們可以以有原理支持的方式逐層分析學習到的模型行為。
ToST 模型不同層次的 TSSA 輸出的變分壓縮項
5. 學習表示的可解釋性分析
ToST 通過統計量驅動的注意力機制,使每一層的注意力操作更加透明,便于解釋和分析。其分組機制展現了 token 特征在低維空間中的聚類效果,直觀反映了模型的決策過程。
ToST 在無需復雜的自監督訓練的情況下,自然生成了可解釋的注意力模式。
倒數第二個全局類注意力層中最后一個頭部的 [CLS] token 注意力圖的比較
在 TSSA 層中,可視化估計的隸屬矩陣 Π 的每一行(經過重塑后)
可能對未來產生的影響
1. 大模型的高效化
隨著語言模型、生成模型和多模態模型規模的持續擴展,計算效率成為核心瓶頸。ToST 展示的統計量驅動注意力機制,為實現線性復雜度的大模型提供了可能性。
2. 推動 Transformer 的普適化應用
高效的注意力機制使得 ToST 能夠更廣泛地應用于資源受限場景,如邊緣計算、實時系統、嵌入式設備等。這為人工智能技術從中心化計算向分布式、邊緣化方向的發展奠定了基礎。
3. 多模態融合的可能性
ToST 的低復雜度機制為處理多模態長序列任務提供了新的技術框架,使未來多模態大模型在生成、分析和交互中的效率顯著提升。
4. 促進跨學科應用
ToST 對數學理論與工程實現的有機結合,不僅在傳統 AI 任務中表現突出,還可能推動其在新興領域(如量子計算、生物信息學和材料設計)中的應用。
Token Statistics Transformer (ToST) 重塑了注意力機制,它不需要計算 token 之間的兩兩交互,而是基于投影后 token 特征的二階矩統計量構建,其基于數據壓縮和表示學習的理論原則目標,為 Transformer 的發展開辟了新路徑。其基于統計特性的低復雜度設計,不僅優化了現有架構的性能,還為未來大模型的高效化、多模態融合和跨學科應用提供了啟示。