譯者 | 朱先忠
審校 | 重樓
簡介
模型蒸餾是一種機器學習新技術,其基本思想是讓較小的模型(學生)模仿較大的模型(老師)的行為。當前,已經存在幾種方法可以實現這一技術(將在下文中展開具體介紹),但其目標都是在學生模型中獲得比從頭開始訓練更好的泛化能力。
模型蒸餾示例:學生(較小)模型使用蒸餾損失函數從教師模型中學習,該函數使用“軟標簽”和預測(使用OpenAI GPT4o生成的圖表)
一、為什么模型蒸餾很重要?
模型蒸餾是開發和部署大型語言模型(LLM)的關鍵技術。它解決了與這些模型的大小和復雜性相關的如下幾個挑戰:
- 資源效率:大型模型(例如具有1750億個參數的GPT-3)需要大量計算資源進行訓練和推理。這使得它們的部署和維護成本高昂。相反地,蒸餾可減小模型大小,從而降低內存使用量并加快推理時間,這對于硬件功能有限的應用程序尤其有益。
- 部署在邊緣設備上:許多應用程序需要實時處理和低延遲,尤其是在智能手機或物聯網設備等邊緣設備上運行的應用程序。精簡模型更加輕量,可以部署在此類設備上,從而無需依賴持續的云連接即可實現AI功能。
- 降低成本:由于能耗和專用硬件的需求,運行大型模型的成本很高。通過將模型蒸餾為較小的版本,公司或組織可以顯著降低運營費用,同時保持相當的性能水平。
- 提高訓練效率:通過蒸餾得到的較小模型需要更少的數據和計算能力來針對特定任務進行微調。這種效率加快了開發周期,使資源有限的研究人員和從業者更容易利用先進的人工智能模型。
為了說明模型蒸餾的影響,請考慮大型模型與其蒸餾模型之間的以下基準比較:
表1:大型語言模型與精簡模型(較小)之間的示例比較。注意:這些數字僅供參考,實際性能指標可能因實施和硬件而異
在此示例中,蒸餾模型實現了與GPT-3相當的準確率,同時顯著減少了參數數量和推理時間。這證明了蒸餾如何使AI模型在實際應用中更加實用且更具成本效益。
到目前為止,我們已經理解了為什么蒸餾如此重要。現在,讓我們更深入地了解模型蒸餾的細節。
二、什么是模型蒸餾?
想象一下,你正在向一位世界級專家(老師)學習一個復雜的主題,比如量子物理學。這位專家無所不知,但他們使用復雜的語言,需要很長時間才能解釋清楚。現在再想象一下,另一個人——一位偉大的溝通者(學生)——向這位專家學習,然后以一種更簡單、更快捷的方式教你相同的內容,而不會丟失核心信息。這就是模型蒸餾背后的主要思想。
更正式地說,模型蒸餾是一個過程,其中訓練一個較小、更高效的模型(稱為學生)來復制一個較大、更強大的模型(稱為老師)的行為。目標是讓學生更快、更輕松,同時在相同的任務上仍然表現良好。
蒸餾類型
模型蒸餾并不局限于教學生最終的答案是什么。學生可以通過多種方式向老師學習。以下是三種主要類型:
1.基于Logit的蒸餾(軟標簽):學生模型從老師的概率分布中學習
我們不只是對學生進行正確答案(硬標簽)的訓練,還讓它了解老師對每個答案的信心程度——這些被稱為軟標簽或軟目標。
為什么要使用軟目標?
假設你正在訓練一個模型來對動物進行分類:
- 輸入是一張狼的圖像。
- 老師輸出:
[Wolf: 1.0, Dog: 0.0, Cat: 0.0, Fox: 0.0]
- 學生不僅知道應該說“狼”,還知道狗和狐貍有些相似。這種額外的細微差別有助于學生學習更好的決策界限。
相比之下,僅使用硬標簽進行訓練將會是這樣的:
[Wolf: 1.0, Dog: 0.0, Cat: 0.0, Fox: 0.0]
沒有細微差別=更難學習細微的差別。
學生學到了什么:
- 哪些類別可能或不可能
- 老師如何處理不確定性
為什么這種類型很有用:
- 軟目標就像一個更平滑的訓練信號,特別是對于難以學習的例子
- 鼓勵學生更好地概括,而不是僅僅依靠硬標簽
這是最常用的方法,尤其是對于分類任務。
2.基于特征的蒸餾:學生模仿老師的中間層表征
學生被訓練模仿老師的隱藏層激活,而不僅僅是輸出。
你可以將教師模型視為在內部逐步解決復雜問題。在基于特征的蒸餾中,學生模型會嘗試復制教師解決問題的方式,而不僅僅是最終答案。
學生學到了什么:
- 內部推理模式
- 輸入數據的分層表示
- 嵌入結構和注意力圖(在Transformer中)
為什么這種類型很有用:
- 幫助學生學習更豐富的表現形式,尤其是在模型較小的情況下
- 可以提高對新任務的可轉移性
- 促進多模式或復雜模型的更好對齊
此類型用于TinyBERT和MobileBERT等蒸餾技術。
3.基于關系的蒸餾:捕獲多個實例之間的關系,而不僅僅是實例方面的知識
學生模型不僅從老師的輸出中學習,還從老師的表征空間中不同數據樣本之間的關系中學習。
基于關系的蒸餾并不關注個別的預測,而是教導學生保留數據的結構——例如,如果老師發現兩個句子相似,那么學生也應該學會將它們視為相似。
學生學到了什么:
- 實例之間的相對距離和相似性
- 嵌入空間中的分組、聚類或其他結構知識
為什么這種類型很有用:
- 在度量學習、對比學習或檢索任務中尤其有效
- 鼓勵學生學習與老師相同的“思維導圖”
- 使學生能夠適應輸入分布的變化
此類型用于更高級或以研究為重點的蒸餾方法(例如,RKD——關系知識蒸餾)。
小結
三、如何進行模型蒸餾(附代碼示例)
第1步:加載預訓練教師模型
使用Hugging Face的轉換器來加載大型模型(例如DistilBERT的老師:BERT)。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
teacher_model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_name)
第2步:從頭開始訓練學生模型
初始化一個較小的模型,例如:
distilbert-base-uncased。
student_model_name = "distilbert-base-uncased"
student_model =
AutoModelForSequenceClassification.from_pretrained(student_model_name)
第3步:實現知識蒸餾損失
在老師和學生的預測之間使用KL散度。
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
hard_loss = F.cross_entropy(student_logits, labels)
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reductinotallow="batchmean"
) * (T ** 2)
return alpha * soft_loss + (1 – alpha) * hard_loss
第4步:使用教師的軟目標訓練學生模型
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
for batch in train_dataloader:
input_ids, attention_mask, labels = batch
with torch.no_grad():
teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits
student_logits = student_model(input_ids, attention_mask=attention_mask).logits
loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
第5步:評估蒸餾后的模型
比較原始模型和蒸餾模型的準確度和推理時間:
import time
def evaluate_model(model, dataloader):
model.eval()
correct, total = 0, 0
start_time = time.time()
with torch.no_grad():
for batch in dataloader:
input_ids, attention_mask, labels = batch
outputs = model(input_ids, attention_mask=attention_mask).logits
predictions = torch.argmax(outputs, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
inference_time = time.time() - start_time
accuracy = correct / total
return accuracy, inference_time
teacher_acc, teacher_time = evaluate_model(teacher_model, test_dataloader)
student_acc, student_time = evaluate_model(student_model, test_dataloader)
print(f"Teacher Accuracy: {teacher_acc:.4f}, Inference Time: {teacher_time:.2f}s")
print(f"Student Accuracy: {student_acc:.4f}, Inference Time: {student_time:.2f}s")
結論
隨著大型語言模型不斷突破AI的極限,它們也帶來了現實的弊端:推理速度慢、能耗高、部署能力有限。模型蒸餾通過將大型模型的功能壓縮為更小、更快、更高效的版本,為這些挑戰提供了一種實用而優雅的解決方案。
本文中,我們探索了模型蒸餾的緣由、內容和方式——從通過軟標簽學習到模仿內部表示,甚至保留數據點之間的關系。無論你是為移動應用程序、低延遲API還是邊緣設備構建模型,蒸餾都是在不降低性能的情況下縮小模型的關鍵工具。
蒸餾技術最重要的貢獻在哪里?在于這種技術不僅適用于研究實驗室或科技巨頭。借助Hugging Face Transformers和PyTorch等開源工具,任何人都可以立即開始蒸餾模型。
蒸餾不僅僅是為了讓模型更小,而且還為了讓它們更智能、更快、更易于訪問。隨著人工智能從集中式數據中心轉移到日常設備和應用程序,蒸餾只會變得越來越重要。
譯者介紹
朱先忠,51CTO社區編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。
原文標題:Understanding Model Distillation in Large Language Models (With Code Examples),作者:Edgar