Meta「輕量級」KernelLLM顛覆GPU內核生成,8B參數碾壓GPT-4o
在AI領域,參數規模曾被視為「性能天花板」。
Meta最新發布的KernelLLM,卻用8B參數的「小身板」,在GPU內核生成任務中把200B的GPT-4o按在地上摩擦。
這是一個基于Llama 3.1 Instruct進行微調的8B參數模型,旨在將PyTorch模塊自動轉換為高效的Triton GPU內核。
圖片
KernelLLM簡直是GPU內核開發神器,用更少的參數實現了更強的性能,且簡單易用。
它只有8B參數,但是在KernelBench-Triton Level 1,單次推理性能超過了GPT-4o和DeepSeek V3。
通過多次推理,KernelLLM性能優于DeepSeek R1。
圖片
這一切都來自一個參數規模比競爭對手小兩個數量級的模型。
@Denis Kanonik吐槽「這又是用測試集訓練的嗎?」
圖片
KernelLLM讓內核開發更易上手
KernelLLM是一款基于Llama 3.1 Instruct的8B模型,專門針對用Triton編寫GPU內核的任務進行了訓練。
它能讓GPU編程變得更簡單,實現高性能GPU內核生成的自動化。
KernelLLM通過自動化生成高效的Triton實現,滿足對高性能GPU內核日益增長的需求。
隨著工作負載的增大和加速器架構的多樣化,對定制化內核解決方案的需求顯著增加。
現在市面上很多相關工具,要么只能在測試的時候優化,要么就只盯著KernelBench的問題調優,很難應對更廣泛的場景。
KernelLLM是首個在外部(PyTorch,Triton)代碼對數據上進行微調的LLM。
Triton內核生成工作流程
把PyTorch代碼輸進去,KernelLLM就會生成Triton內核候選代碼。
然后用單元測試來驗證這些代碼,用隨機輸入跑一跑,看看輸出對不對。要是生成好幾個候選代碼,還能比比哪個最好,挑出最優的。
圖片
KernelLLM的Triton內核生成流程:用KernelLLM把PyTorch代碼翻譯成Triton內核的候選代碼。生成的代碼會通過單元測試驗證,測試用已知形狀的隨機輸入數據運行內核。這個流程支持生成多個候選代碼(通過 pass@k評估),增加候選數量來提高質量,最后選出最好的Triton內核實現作為輸出(綠色部分)
為了訓練這個模型,團隊可是下了大功夫,用了25000多對(PyTorch,Triton)代碼示例,還有合成的樣本。
這些數據一部分來自TheStack的過濾代碼,一部分是通過torch.compile () 和提示技術生成的。
數據集KernelBook,參考鏈接:https://huggingface.co/datasets/GPUMODE/KernelBook。
訓練時用的是Llama3.1-8B-Instruct模型,在自定義數據集上做了監督微調(SFT),測試它在KernelBench-Triton上生成正確Triton內核及調用代碼的能力。
KernelBench-Triton是基于KernelBench[Ouyang et al. 2025]開發的變體,專注Triton內核生成。
訓練和評估時,PyTorch代碼會配置一個包含格式示例的提示模板作為指令。
模型訓練了10個epoch,批大小為32,采用標準SFT方法,超參數根據驗證集的困惑度(perplexity)來選擇。
訓練用了16個GPU,共耗時12小時(192 GPU小時),報告了最佳檢查點的驗證結果。
性能評估
盡管模型規模較小,但其性能可與最先進的LLM相媲美。
圖片
KernelBench-Triton測試中,8B參數的KernelLLM,單次推理得分20.2,比671B參數的DeepSeek V3(16分)和200B參數的GPT-4o(15分)都高。
圖片
要是多生成幾個候選代碼,得分還能蹭蹭往上漲,生成10個的時候能到51.8分,20個的時候能到57.1分。
KernelLLM推理用temperature=1.0和top_p=0.97運行。
在KernelBench上測試了模型,這是一個開源基準測試,用于評估LLM編寫的高效GPU內核的能力。
它包含250個精心挑選的PyTorch模塊,按負載調整,從簡單的單操作(如Conv2D或Swish,Level 1)到完整的模型架構(Level 3)。
它在不同難度的任務里表現都很穩,不管是簡單的單個操作符,還是復雜的模型架構,都能應對。
測試會同時降低代碼的正確性(通過與參考PyTorch輸出對比)和性能(通過與基準實現的加速比)。
團隊開發了一個新的KernelBench-Triton變體,專門評估LLM生成Triton內核的能力,非常適合測試KernelLLM。
所有測試都在NVIDIA H100 GPU上完成。
圖片
KernelLLM在pass@k中表現出近似對數線性的擴展行為
KernelLLM怎么用?
先裝幾個依賴包:
pip install transformers accelerate torch triton
pip install transformers accelerate torch triton
用的時候,先導入庫,調用generate_triton函數,就能生成優化后的Triton代碼啦。
KernelLLM提供了一個簡單的接口,用于從PyTorch代碼生成Triton核。
from kernelllm import KernelLLM# Initialize the modelmodel = KernelLLM()# Define your PyTorch modulepytorch_code = '''import torchimport torch.nn as nnclass Model(nn.Module): """ A model that computes Hinge Loss for binary classification tasks. """ def __init__(self): super(Model, self).__init__() def forward(self, predictions, targets): return torch.mean(torch.clamp(1 - predictions * targets, min=0))batch_size = 128input_shape = (1,)def get_inputs(): return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]def get_init_inputs(): return []'''# Generate optimized Triton codeoptimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)print(optimized_code)
from kernelllm import KernelLLM
# Initialize the model
model = KernelLLM()
# Define your PyTorch module
pytorch_code =
'''
import torch
import torch.nn as nnclass Model(nn.Module):
"""
A model that computes Hinge Loss for binary classification tasks.
"""
def __init__(self):
super(Model, self).__init__()
def forward(self, predictions, targets):
return torch.mean(torch.clamp(1 - predictions * targets, min=0))
batch_size = 128
input_shape = (1,)
def get_inputs():
return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]
def get_init_inputs():
return []
'''
# Generate optimized Triton code
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
print(optimized_code)
要是不想寫腳本,還能直接運行python kernelllm.py,使用內置的REPL接口,打開交互式界面,實時看結果。
kernelllm.py提供了多種與模型交互的方法。
python kernelllm.py
python kernelllm.py
KernelLLM提供了幾種自定義生成過程的方法:
from kernelllm import KernelLLMmodel = KernelLLM()# Stream output in real-timemodel.stream_raw("Your prompt here", max_new_tokens=2048)# Generate raw text without the Triton-specific prompt templateraw_output = model.generate_raw("Your prompt here", temperature=1.0, max_new_tokens=2048)
from kernelllm import KernelLLM
model = KernelLLM()
# Stream output in real-time
model.stream_raw("Your prompt here", max_new_tokens=2048)
# Generate raw text without the Triton-specific prompt template
raw_output = model.generate_raw("Your prompt here", temperature=1.0, max_new_tokens=2048)
有時它會犯點小錯誤,比如API引用不對、語法出錯,有時候還不太能按指令生成理想的內核。
生成的代碼結構有點像編譯器自動吐出來的,有時在變量命名、張量形狀、類型處理和數值精度這些細節上也容易出問題。