清華開源混合精度推理系統MixQ,實現大模型近無損量化并提升推理吞吐
一鍵部署LLM混合精度推理,端到端吞吐比AWQ最大提升6倍!
清華大學計算機系PACMAN實驗室發布開源混合精度推理系統——MixQ。
MixQ支持8比特和4比特混合精度推理,可實現近無損的量化部署并提升推理的吞吐。
圖片
△圖1 MixQ吞吐與已有開源工作比較
MixQ同時量化權重和激活,使用低精度張量核心(INT8/INT4 Tensor Core)實現推理加速;同時,MixQ提取激活中少量的離群值,使用高精度張量核心(FP16 Tensor Core)保持推理準確性,通過系統優化掩蓋高精度訪存開銷。
不僅保持推理的準確性,而且通過使用低精度算力有效提升吞吐,充分發揮硬件計算潛力(圖1)。
同時,研究團隊提供了基于VLLM和Tensorrt-LLM的混合精度推理,用戶可以方便地一鍵部署模型。
圖2 使用VLLM一鍵部署4比特和8比特混合精度量化并推理
△
MixQ已支持多個主流大模型LLaMA3,Qwen2,Baichuan2,ChatGLM等。據了解,目前MixQ開源技術已被清程極智等AI行業公司應用在實際產品中。
該工作同時于高性能計算領域頂級國際會議SC’24發表,第一作者清華大學博士后陳逸東、通訊作者為翟季冬教授。
圖片
研究背景:已有量化技術總結
量化的主要技術路線有兩條,第一條是權重量化。
權重量化的理論加速比是16/量化的比特數。例如,將模型壓縮成為4bit,那么理論加速比為16/4=4倍。
然而,當服務商面臨大量的用戶同時訪問時,權重量化的系統吞吐會低于FP16的吞吐,其主要原因是權重量化計算過程中將低精度權重恢復成FP16然后計算,這導致權重量化并不使用低精度算力,當場景表現為compute bound的時候,性能較低。
△圖3 用戶請求多權重量化吞吐低于FP16
第二條技術路線是量化權重和激活,使用低精度的張量核心來提升系統的吞吐。
直接將激活量化為低比特可能會出現較大的精度損失。其原因在于激活矩陣中存在離群值(圖4)。
一個有效的方法是SmoothQuant,主要思想是通過平滑激活矩陣來降低量化激活的誤差。
△圖4 激活矩陣中存在離群值
混合精度量化則是一類全新的量化方法,該方案先做了一個矩陣分解,對絕大部分權重和激活用低比特存儲,將離群值用FP16存儲,分別做矩陣乘法。
圖片
△圖5 混合精度量化示意圖
混合精度量化的一個優勢就是可以實現近乎無損精度的量化。使用混合精度量化的LlaMA模型在MMLU 20個領域上的數據集進行推理準確率測試表明,采用8bit混合精度量化后的準確率下降不到0.1%:
圖6 混合精度量化分類準確率
不過,此前已有的混合精度量化的系統的性能普遍不高,主要瓶頸在針對離群點進行查找、訪存和計算的開銷占比大。
以混合精度庫Bitsandbytes為例,實測試表明,Bitsandbytes在用戶請求數量為512時僅有1.08倍的加速。
圖7 Bitsandbytes的在LLaMA70B上的Kernel性能測試
圖8 Atomic operator是混合精度推理系統的瓶頸之一
那么,如何優化對離群點的查找、訪存和計算的開銷呢?
MixQ的解決方案
MixQ的核心思想是基于離群點的局部性對混合精度的計算圖做等價變換,使得變換后的混合精度的計算圖可以避免離群點查找的額外開銷;在此基礎上,通過圖層融合和設計高效的混合精度數據結構降低訪存開銷;最后通過CUTLASS生成高性能的混合精度算子,達到提升系統性能的效果。
MixQ的設計基于以下的觀察:
離群點的局部性。對LLM的激活矩陣分析發現,在不同的decode階段的離群點的分布是有規律的。
如圖9,紅色的點表示的是第一次出現的離群點,綠色的點表示的是重復出現的離群點,隨著decode的進行,多數離群點出現在了固定的channel。
因此,研究人員得到一個重要的結論:在大部分的decode階段是不需要重復檢測離群點的,也就是說我們可以避免檢查離群點的開銷。
剩下的問題是,如何知道哪些時候不需要重復檢查離群點呢?這個答案就隱藏在量化系數中。
在量化的過程中需要對矩陣進行amax的操作。因此,通過amax得到的結果可以判斷矩陣中是否存在離群點。如amax的值大于閾值,那矩陣中存在離群點。反之則不存在。
更重要的是,amax操作可以和前一個操作融合。這樣不僅以極低的代價檢測離群點的存在,還通過對圖層進行融合來降低量化的開銷。
基于以上的分析,MixQ的設計使用了三個關鍵技術:
一是對計算圖的等價變換。
針對混合精度的計算邏輯進行了等價變換以后,通過計算激活矩陣的amax的值,避免了檢測離群點的開銷。
圖片
圖10 優化混合精度的計算邏輯
二是設計混合精度數據結構。
MixQ將離群點“拼接”成了一個新的矩陣。這一方法相較于ATOM采用的重排列(reorder)具有更低的開銷。
圖11 MixQ:order-reserved數據結構
三是使用CUTLASS編寫高性能的混合精度的算子,這一關鍵技術的實現依賴于NVIDIA提供的高性能矩陣乘法模板CUTLASS 3.x。
MixQ在寄存器中反量化低精度的計算結果并與高精度的結果進行相加。
圖12 融合dequantize、scale和add操作
下面來看MixQ的實驗結果,以LLaMA 70B為例。
在準確率表現方面,MixQ的準確率和Bitsandbytes一致。
圖13 MixQ的推理精度
在性能表現方面,MixQ 8bit kernel是Bitsandbytes的1.9倍。
MixQ 4bit Kernel的性能達724TFLOPs,是FP16的3.13倍。
圖片
△圖14 MixQ Kernel性能
端到端測試下,MixQ在batch=512相對Bitsandbytes和AWQ加速1.78和6倍。
圖片
圖15 多batch測試;上:MIXQ的推理輸出(19.21it/s);下:FP16的推理輸出 (1
項目地址:
[1]https://github.com/Qcompiler/MixQ_Tensorrt_LLM
[2]https://github.com/Qcompiler/MIXQ
[3]https://github.com/Qcompiler/vllm-mixed-precision