DeepSeek開源第三彈:V3/R1訓練推理關鍵秘籍,核心代碼僅300行
開源周的第三天,DeepSeek把訓練推理V3/R1背后的“動力”給亮出來了——
DeepGEMM:一個FP8 GEMM(通用矩陣乘法)庫,支持密集(dense)和混合專家(MoE)矩陣乘法運算。
圖片
我們先來簡單了解一下GEMM。
GEMM,即通用矩陣乘法,是線性代數中的基本運算,是科學計算、機器學習、深度學習等領域中“常客”,也是許多高性能計算任務的核心。
但由于它的計算量往往都比較大,所以GEMM的性能優化是至關重要的一點。
而DeepSeek這次開源的DeepGEMM,依舊是保持了“高性能+低成本”的特性,亮點如下:
- 高性能:在Hopper架構的GPU上,DeepGEMM能夠實現高達1350+FP8 TFLOPS的性能。
- 簡潔性:核心邏輯僅約 300 行代碼,但性能卻優于專家調優的內核。
- 即時編譯(JIT):采用完全即時編譯的方式,這意味著它可以在運行時動態生成優化的代碼,從而適應不同的硬件和矩陣大小。
- 無重依賴:這個庫設計得非常輕量級,沒有復雜的依賴關系,可以讓部署和使用變得簡單。
- 支持多種矩陣布局:支持密集矩陣布局和兩種 MoE 布局,這使得它能夠適應不同的應用場景,包括但不限于深度學習中的混合專家模型。
簡單來說,DeepGEMM主要用于加速深度學習中的矩陣運算,特別是在大規模模型訓練和推理中,它特別適用于需要高效計算資源的場景,能夠顯著提升計算效率。
很多網友們對這次的開源都比較“買單”,有人將DeepGEMM比作數學界的超級英雄,認為它比飛快的計算器還要快,比多項式方程還要強大。
圖片
也有人將DeepGEMM的發布比喻為量子態穩定到一個新的現實,稱贊其即時編譯的干凈利落。
圖片
當然……也有人開始擔心起自己手上的英偉達股票了……
圖片
深入了解DeepGEMM
DeepGEMM是一個專門為實現簡潔高效的FP8通用矩陣乘法(GEMMs)而打造的庫,它還具備細粒度縮放功能,這一設計源于DeepSeek V3。
它既能處理普通的通用矩陣乘法,也能支持MoE分組的通用矩陣乘法。
這個庫是用CUDA編寫的,安裝的時候不需要編譯,因為它會在運行時通過一個輕量級的即時編譯(JIT)模塊來編譯所有的內核程序。
目前,DeepGEMM只支持英偉達的Hopper張量核心。
為了解決FP8張量核心在計算累積時不夠精確的問題,它采用了CUDA核心的兩級累積(提升)方法。
雖然DeepGEMM借鑒了CUTLASS和CuTe里的一些理念,但并沒有過度依賴它們的模板或代數運算。
相反,這個庫設計得很簡潔,只有一個核心內核函數,代碼量大概300行左右。
這使得它成為一個簡潔易懂的資源,方便大家學習Hopper架構下的FP8矩陣乘法和優化技術。
盡管其設計輕巧,但DeepGEMM的性能可以匹配或超過各種矩陣形狀的專家調優庫。
那么具體性能如何呢?
團隊在H800上使用NVCC 12.8測試了DeepSeek-V3/R1推理中可能使用的所有形狀(包括預填充和解碼,但沒有張量并行)。
下面這張圖展示的是用于密集模型的普通DeepGEMM的性能:
圖片
從測試結果來看,DeepGEMM計算性能最高可達1358 TFLOPS,內存寬帶最高可達2668 GB/s。
加速比方面,與基于CUTLASS 3.6的優化實現相比,最高可達2.7倍。
再來看下DeepGEMM支持MoE模型的連續布局(contiguous layout)的性能:
圖片
以及支持MoE模型掩碼布局(masked layout)的性能是這樣的:
圖片
如何使用?
要想使用DeepGEMM,需先注意一下幾個依賴項,包括:
- 必須支持Hopper架構的GPU,sm_90a。
- Python 3.8及以上。
- CUDA 12.3及以上(推薦12.8)。
- PyTorch 2.1及以上。
- CUTLASS 3.6及以上
Development代碼如下:
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop
# Test JIT compilation
python tests/test_jit.py
# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
安裝代碼如下:
python setup.py install
在上述步驟之后,您的Python項目中導入deep_gemm即可。
在接口方面,對于普通的DeepGEMM,可調用deep_gemm.gemm_fp8_fp8_bf16_nt函數,支持NT格式(非轉置LHS和轉置RHS)。
對于分組的DeepGEMM,連續布局情況下是m_grouped_gemm_fp8_fp8_bf16_nt_contiguous;掩碼布局情況下是m_grouped_gemm_fp8_fp8_bf16_nt_masked。
DeepGEMM還提供設置最大SM數量、獲取TMA對齊大小等工具函數;支持環境變量,如DG_NVCC_COMPILER、DG_JIT_DEBUG等。
除此之外,DeepSeek團隊還提供了幾種優化的方式,包括:
- JIT設計:所有內核在運行時編譯,無需安裝時編譯;支持動態選擇最優塊大小和流水線階段。
- 細粒度縮放:通過CUDA核心兩層累加解決FP8精度問題;支持非2的冪次方塊大小,優化SM利用率。
- FFMA SASS交錯:通過修改SASS指令的yield和reuse位,提高性能。
圖片
感興趣的小伙伴可以戳文末GitHub鏈接查看詳情哦~
One More Thing
英偉達這幾天的股票……嗯……一直再跌:
圖片
不過在北京時間27日凌晨,英偉達2025財年第四季度業績報告也即將出爐,我們可以期待一下它的表現~