√N并行+84倍計算加速!英偉達港大全新圖像注意力:空間結構都保留
Transformer 及其核心的注意力機制在自然語言處理和計算機視覺等領域帶來了革命性進展,展現出強大的深度上下文建模和數據間復雜依賴關系捕捉能力。
然而,其在處理視覺數據時面臨兩大核心挑戰:
- 二次計算復雜度使其難以高效處理高分辨率圖像等長上下文數據;
- 忽略空間結構,將多維圖像視為無結構的一維標記序列,破壞了圖像固有的空間連貫性,而這種信息對于依賴空間關系的視覺任務至關重要。
為克服效率瓶頸,近期研究如線性注意力和狀態空間模型(如 Mamba) 致力于將復雜度降低至線性。
然而,這些方法在提升效率的同時,依然未能有效保留和利用圖像的關鍵二維空間結構信息,本質上仍是序列化處理。
嘗試將一維光柵掃描(raster scan)擴展至二維的線掃描方法(line scan)是增強空間連貫性的一種思路。
但二維線性傳播面臨嚴峻挑戰:標量權重變為連接像素與前序鄰居的矩陣權重。在傳播過程中累積的矩陣乘法極易導致穩定性問題——矩陣特征值過大引發指數增長(不穩定),過小則導致信號迅速衰減(信息消失)。
因此,在二維空間中同時實現穩定性和維持長距離上下文的有效傳播,是一個亟待解決的難題。
針對上述挑戰,來自英偉達、香港大學和UCSD的研究人員提出廣義空間傳播網絡(GSPN),一種專為視覺任務優化的新型注意力機制,其核心優勢在于直接操作空間連貫的圖像數據,通過高效的線掃描方法建立密集的像素間連接。
論文地址:https://arxiv.org/abs/2501.12381
項目主頁:https://whj363636.github.io/GSPN/
代碼:https://github.com/NVlabs/GSPN
GSPN成功的關鍵在于其提出的穩定性-上下文條件(Stability-Context Condition),該條件確保了跨二維序列的穩定長上下文傳播,并將具有N個元素的圖像的復雜度顯著降低至√N 量級。
因此,GSPN能夠在保持卓越空間保真度的同時,實現極高的計算效率,并在ImageNet分類、類引導圖像生成及文本到圖像生成等任務中達到先進性能。例如,在生成16K圖像時,GSPN相比基于softmax注意力的SD-XL加速超過84倍。
論文第一作者為王弘焌,香港大學統計系博士三年級學生,目前為NVIDIA research intern,研究方向包括高效基礎模型、開放世界理解。
GSPN方法
二維線性傳播
二維線性傳播通過逐行或逐列的順序處理進行。對于二維圖像,其遵循線性循環過程,隱藏層通過前一行的隱藏狀態和當前輸入計算得出。
將隱藏狀態和輸入的行向量連接成序列后,可表示為輸入與一個下三角矩陣的乘積,輸出則為輸入的加權和,該公式可類比為帶因果掩碼的非歸一化線性注意力機制,其中額外的傳播矩陣調制注意力強度。
穩定性-上下文條件
在傳播過程中上述累積的矩陣乘法極易導致穩定性問題。
為實現穩定且有效的長距離傳播,研究人員引入定理1和定理2(統稱為穩定性-上下文條件)。
定理1指出,若所有矩陣均為行隨機矩陣,則滿足各元素加權和為1
定理2表明,行隨機矩陣可確保傳播過程的穩定性。行隨機矩陣的定義為元素非負且每行元素之和為1,乘積仍為行隨機矩陣,這為穩定傳播提供了數學基礎。
傳播層的關鍵實現
對于二維線性循環過程,研究人員對前序狀態的三鄰居連接來計算當前時刻的隱藏層(每個像素連接前一行的三個相鄰像素)以提高參數效率。
文中同時還提出GSPN的兩種變種,全局GSPN和局部GSPN:
全局GSPN捕捉整個序列的長距離依賴,局部GSPN通過將空間維度劃分為非重疊組來限制傳播序列長度,提高效率。
最后,通過四方向集成確保全像素連接,形成密集成對連接。
對每個傳播方向的矩陣元素應用 sigmoid 函數并歸一化,以保證行隨機性。
通過定制的CUDA內核實現線性傳播層,采用并行化結構,在批量、通道和與傳播方向正交的行/列上實現全并行化,有效減少內核循環長度,實現高效可擴展的線性傳播。
GSPN架構
GSPN是一個通用序列傳播模塊,可無縫集成到各種視覺任務的神經網絡中。針對判別任務和生成任務設計了不同的GSPN塊,均基于核心GSPN模塊構建:
- GSPN模塊:通過共享1×1卷積進行降維,再通過三個獨立的1×1卷積生成依賴于輸入的參數,用于二維線性傳播,這些投影和傳播封裝在模塊化的GSPN單元中。
- 圖像分類架構:采用Swin-Transformer的四級分層架構,通過堆疊設計良好的GSPN塊,在相鄰層級間進行下采樣操作,平衡計算效率和表示能力。
- 類條件圖像生成架構:重新設計生成架構,通過向量嵌入加法集成時間步和條件信息,包含跳躍連接和線性投影,去除位置嵌入并引入FFN進行通道混合。
- 文本到圖像生成架構:將GSPN模塊直接集成到Stable Diffusion架構中,替換所有自注意力層,利用預訓練權重初始化參數,加速訓練。
實驗結果
圖像分類
在ImageNet-1K分類任務中,GSPN在參數數量相當的情況下優于現有序列模型,GSPN在從小型到基礎配置的模型規模上表現出一致的性能提升,證明了其可擴展性。
類條件圖像生成
與多種基線方法相比,GSPN-XL/2在ImageNet 256×256類條件生成任務中建立了新的最先進性能,GSPN-L/2僅使用先前模型65.6%的參數就獲得了更優的FID和IS分數,GSPN-B/2在收斂時僅使用DiT-XL/2 20.3%的參數就實現了有競爭力的性能,驗證了GSPN的效率和可擴展性。
文本到圖像生成
GSPN由于其歸一化權重滿足穩定性-上下文條件,無需額外歸一化即可適應任意分辨率,在不使用任何預訓練權重且在相同訓練輪數內達到了與baseline相當的性能。
此外,GSPN在單塊A100 GPU上生成16K×8K分辨率圖像可實現約84倍的加速。
總結
研究人員提出了廣義空間傳播網絡(GSPN),這是一種用于視覺任務中并行序列建模的新型注意力機制。
通過穩定性-上下文條件確保穩定且上下文感知的傳播,GSPN在保持效率的同時將序列復雜度減少到√N
實驗表明,GSPN在多個視覺任務中實現了最先進的結果和顯著的加速,展示了其在視覺任務中的效率和潛力。
未來,GSPN有望在更多視覺領域及視覺多模態模型中發揮重要作用,推動下一代視覺理解和生成基礎結構的發展。