斯坦福博士獨作!大模型訓練速度再翻倍,還官宣加入明星創業公司當首席科學家
本文經AI新媒體量子位(公眾號ID:QbitAI)授權轉載,轉載請聯系出處。
現有大語言模型的訓練和推理速度,還能再快一點——
快多少?2-4倍。
各種大模型都在用的FlashAttention今天正式發布第2代并開源,所有Transformer架構的模型都可使用它來加速。
圖片
一代方法去年6月發布,無需任何近似即可加速注意力并減少內存占用。
現在,FlashAttention-2將它再度升級,使其核心注意力操作的速度再提高2倍,端到端訓練Transformer時的速度再提高1.3倍,并可在英偉達A100上訓練時實現72%的模型FLOP利用率(一般模型都在50%上下)。
圖片
鑒于現在煉一個大語言模型的成本高達數千萬美元,FlashAttention-2這一系列操作直接就能幫我們省掉數百萬(美元)!
網友驚得臟話都出來了(狗頭):
圖片
目前,這個項目已在GitHub上收獲4.4k標星。
與此同時,我們注意到,它的一作已經完成斯坦福博士學位并加盟大模型創業公司Together AI。
具體實現
據介紹,一代FlashAttention是一種對注意力計算重新排序的算法,它利用經典方法如tiling(切片)來顯著加快計算速度,并將序列長度的內存使用量從二次方減為線性。
其中tiling方法指的是將輸入塊從HBM(GPU內存)加載到SRAM(快速緩存),然后對該塊進行attention操作,再更新HBM中的輸出。
對HBM的反復讀寫就成了最大的性能瓶頸。
圖片
正是這種通過避免將大型中間注意力矩陣寫入HBM的方法,FlashAttention減少了內存讀/寫量,從而帶來2-4倍的時鐘時間加速。
然而,這個算法仍然存在一些低效率的問題,導致它仍然不如優化矩陣乘法 (GEMM) 運算來得快,最終僅達到理論最大FLOPs/s的25-40%(例如在A100上最多124 TFLOPs/s)。
究其原因,還是因為不同線程塊之間的工作和GPU上的wrap劃分不理想。
在此,FlashAttention-2進行了三方面的改進。
首先,在基礎算法上,減少非matmul(矩陣乘法) FLOP的數量。
一層原因是由于現代GPU具有專門的計算單元,matmul速度更快。例如A100上FP16/BF16 matmul的最大理論吞吐量為312TFLOPs/s,但非matmul FP32的理論吞吐量僅為19.5 TFLOPs/s。
另一層原因是價格考量,畢竟每個非matmul FLOP比matmul FLOP貴16倍。同時在matmul FLOP上花費盡可能多的時間也能保持高吞吐量。
為此,作者重寫了FlashAttention中的softmax trick,無需更改輸出即可減少重新縮放操作的數量,以及邊界檢查和因果屏蔽操作(causal masking operation)。
其次,當batch size較小時并行化以獲得更高的占用率。
FlashAttention一代在batch size和注意力頭數量上進行并行化。
由于它使用1個線程塊來處理1個注意力頭,總共就有(batch_size*注意力頭數)個線程塊,每個線程塊被安排在流式多處理器 (SM) 上運行。
當在像A100這樣有108個SM處理器上操作時,如果線程塊很多比如>=80,這樣的調度安排就很有效。
而在長序列的情況下,也就是batch size和頭數量很少(小)時,就需要在序列長度維度上另外進行并行化來更好地利用GPU上的多處理器了。
這個改進也是FlashAttention-2速度顯著提升的一大原因。
最后,改進工作分區。
在線程塊內,我們必須確定如何在不同的warp之間劃分工作。通常是每個塊使用4或8個warp,現在,作者改進了這一方式,來減少不同warp之間的同步和通信量,從而減少共享內存讀寫操作。
如下圖左所示,FlashAttention一代的做法是將K和V分割到4個warp上,同時保持Q可被所有warp訪問。這樣的后果是所有warp都需要將其中間結果寫入共享內存,然后進行同步再將中間結果相加,非常低效,減慢了FlashAttention中的前向傳播速度。
圖片
而在FlashAttention-2中,作者將Q分為四個warp,同時保證所有warp都可訪問K和V。
每個warp執行矩陣乘法獲得Q K^T的切片后,只需與V的共享切片相乘即可獲得相應的輸出。也就是說warp之間不需要通信,那么共享內存讀寫操作就少了很多,速度也就提上來了。
除了這三個大改進,FlashAttention-2還有兩個小改動:
一是注意力頭數從128增至256,這意味著GPT-J、CodeGen和CodeGen2以及StableDiffusion 1.x等模型都可以使用 FlashAttention-2來進行加速和內存節省了;
二是支持多查詢注意力(MQA)和分組查詢注意力(GQA)。
實驗評估
作者在A100 80GB SXM4 GPU上對不同配置(有無causal mask,頭數量64或128)下的運行時間進行了測量。
結果發現:
FlashAttention-2比FlashAttention(包括xformers庫和Triton中的其他實現)快大約2倍,這也意味我們可以用與之前訓練8k上下文模型相同的價格來訓練具有16k上下文的模型了(也就是模型上下文長度加倍)。
而與PyTorch中的標準注意力實現相比,FlashAttention-2的速度最高可達9倍。
圖片
此外,有了FlashAttention-2,我們只需在H100 GPU上運行相同的實現(不使用特殊指令利用TMA和第四代Tensor Core等新硬件功能),訓練速度就可以跑到高達335TFLOPs/s的成績。
圖片
以及當用于端到端訓練GPT式模型時,FlashAttention-2還能在A100上實現高達225TFLOPs/s的速度(模型FLOPs利用率達72%)。這與已經優化程序足夠高的FlashAttention相比,速度再提高了1.3倍。
圖片
一作加入大模型創業公司
FlashAttention-2論文僅顯示一位作者:Tri Dao。他也是FlashAttention一代的兩位共同作者之一。
圖片
據了解,Tri Dao的研究方向為機器學習和系統的交叉領域,去年拿下ICML 2022杰出論文亞軍獎。
最近他剛剛獲得斯坦福大學計算機科學博士學位,即將上升普林斯頓大學助理教授,并已宣布加盟生成式AI創業公司Together AI(該司主要目標構建一個用于運行、訓練和微調開源模型的云平臺)擔任首席科學家。
One More Thing
最后,有網友發現,除了FlashAttention-2,最近還有一系列類似成果,包括DeepSpeed的ZeRO++、馬薩諸塞大學de ReLoRA。
它們都是用于加速大型模型預訓練和微調,這些研究成果讓他覺得:
未來在低vram低帶寬的消費顯卡上訓練大模型,似乎已不是在做夢了。
圖片
大家認為呢?
論文地址:https://tridao.me/publications/flash2/flash2.pdf
博文地址:https://princeton-nlp.github.io/flash-atttention-2/
GitHub主頁:https://github.com/Dao-AILab/flash-attention