新PyTorch API:幾行代碼實現不同注意力變體,兼具FlashAttention性能和PyTorch靈活性
理論上,注意力機制就是你所需要的一切。然而在實際操作中,我們還需要優化像 FlashAttention 這樣的注意力機制的實現。
盡管這些融合的注意力機制大大提高了性能,且支持長上下文,但這種效率的提升也伴隨著靈活性的喪失。對于機器學習研究人員來說,這就像是一種「軟件彩票」—— 如果你的注意力變體不適合現有的優化內核,你將面臨運行緩慢和 CUDA 內存不足的困境。
一些注意力變體包括因果注意力、相對位置嵌入、Alibi、滑動窗口注意力、PrefixLM、文檔掩碼、不規則張量、PagedAttention 等。更糟糕的是,人們通常希望將這些變體組合在一起!比如滑動窗口注意力 + 文檔掩碼 + 因果注意力 + 上下文并行,又比如 PagedAttention + 滑動窗口的組合。
下圖左側代表了當今的現狀 —— 一些掩碼 + 偏置 + 設置的組合已經有現成的內核實現。然而,各種選項的添加會導致設置呈指數級增長。更糟糕的是,這種方式不會支持新的注意力變體。
為了徹底地解決這個超立方體問題,PyTorch 團隊引入了 FlexAttention,一個新的 PyTorch API。
- FlexAttention 是一個靈活的 API,允許用戶使用幾行慣用的 PyTorch 代碼就能實現多個注意力變體。
- 團隊人員通過 torch.compile 將其降低到一個融合的 FlashAttention 內核中 ,生成了一個不會占用額外內存且性能可與手寫內核相媲美的 FlashAttention 內核。
- 利用 PyTorch 的自動求導機制自動生成反向傳播。
- 最后,PyTorch 團隊還可以利用注意力掩碼中的稀疏性,從而顯著改善標準注意力實現。
FlashAttention 1-3 版本的參與者 Tri Dao 對這項研究進行了轉發并評論:這項研究使得很多技術都融合在一起了。
FlexAttention
經典的注意力方程式如下:
代碼形式:
FlexAttention 形式如下,其通過接受用戶定義的函數 score_mod 來解決上述問題。
代碼形式:
此函數允許用戶在 softmax 之前修改注意力分數。研究人員發現,該函數最終足以滿足大多數用戶對注意力變體的需求。
具體而言,score_mod 如下:
要應用此函數,可以將其實現為:
for b in range (batch_size):
for h in range (num_heads):
for q_idx in range (sequence_length):
for kv_idx in range (sequence_length):
modified_scores [b, h, q_idx, kv_idx] = score_mod (scores [b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)
最終的 API 具有令人驚訝的表達能力。
Score Mod 示例
全注意力
在這種情況下,score_mod 無操作,它接受分數作為輸入,然后原樣返回它們。
然后端到端的使用。
相對位置編碼
一種常見的注意力變體是相對位置編碼。相對位置編碼不是對查詢和鍵中的絕對距離進行編碼,而是根據查詢和鍵之間的距離調整分數。
需要注意的是,與典型實現不同,這不需要具體化 SxS 張量。相反,FlexAttention 會在內核中動態計算偏差值,從而顯著提高內存和性能。
Soft-capping
Soft-capping 是 Gemma 2 和 Grok-1 使用的一種技術,在 FlexAttention 中,它的形式是這樣的:
Causal Mask
盡管雙向注意力很簡單,但在論文《Attention is All You Need》,以及其他的 LLM 中,它們的設置都是僅解碼器的注意力,其中每個 token 只能關注它之前的 token。如果用戶使用 score_mod API ,可以將其表示為:
Sliding Window + Causal
圖源:https://arxiv.org/abs/2310.06825
Mistral 一直在推廣滑動窗口注意力(也稱為局部注意力),它允許查詢 token 僅關注最近的 1024 個 token,通常與因果注意力一起使用。
研究者對帶有滑動窗口掩碼的 F.scaled_dot_product_attention 以及帶有因果掩碼的 FA2 進行基準測試。結果表明,FlexAttention 不僅明顯快于 F.scaled_dot_product_attention,也明顯快于帶有因果掩碼的 FA2。
性能
總體而言,FlexAttention 的性能幾乎與手寫的 Triton 內核一樣好。然而,由于 FlexAttention 具有通用性,因此會遭受輕微的性能損失。例如,用戶必須承受一些額外的延遲。
FlexAttention 在前向傳播中實現了 FlashAttention2 性能的 90%,在反向傳播中實現了 85%。FlexAttention 目前正在使用一種確定性算法,該算法比 FAv2 重新計算了更多的中間體,研究者計劃改進 FlexAttention 的反向算法,來縮小這一差距!