模型蒸餾:“學神”老師教出“學霸”學生 原創 精華
編者按: 近日,Qwen 3 技術報告正式發布,該系列也采用了從大參數模型中蒸餾知識來訓練小參數模型的技術路線。那么,模型蒸餾技術究竟是怎么一回事呢?
今天給大家分享的這篇文章深入淺出地介紹了模型蒸餾的核心原理,即通過讓學生模型學習教師模型的軟標簽而非硬標簽,從而傳遞更豐富的知識信息。作者還提供了一個基于 TensorFlow 和 MNIST 數據集的完整實踐案例,展示了如何構建教師模型和學生模型,如何定義蒸餾損失函數,以及如何通過知識蒸餾方法訓練學生模型。實驗結果表明,參數量更少的學生模型能夠達到與教師模型相媲美的準確率。
作者 | Wei-Meng Lee
編譯 | 岳揚
Photo by 戸山 神奈 on Unsplash
如果你一直在關注 DeepSeek 的最新動態,可能聽說過“模型蒸餾”這個概念。但究竟什么是模型蒸餾?它為何重要?本文將解析模型蒸餾原理,并通過一個 TensorFlow 示例進行演示。通過閱讀這篇技術指南,我相信您將對模型蒸餾有更深刻的理解。
01 模型蒸餾技術原理
模型蒸餾通過讓較小的、較簡單的模型(學生模型)學習模仿較大的、較復雜的模型(教師模型)的軟標簽(而非原始標簽),使學生模型能以更精簡的架構繼承教師模型的知識,用更少參數實現相近性能。以圖像分類任務為例,學生模型不僅學習“某張圖片是狗還是貓”的硬標簽,還會學習教師模型輸出的軟標簽(如80%狗,15%貓,5%狐貍),從而掌握更細粒度的知識。 這一過程能在保持高準確率的同時大大降低模型體積和計算資源需求。
下文我們將以使用 MNIST 數據集訓練卷積神經網絡(CNN)為例進行演示。
MNIST 數據集(Modified National Institute of Standards and Technology)是機器學習和計算機視覺領域廣泛使用的基準數據集,包含 70,000 張 28x28 像素的手寫數字(0-9)灰度圖像,其中 60,000 張訓練圖像和 10,000 張測試圖像。
首先構建教師模型:
Image by author
教師模型是基于 MNIST 訓練的 CNN 網絡。
同時構建更輕量的學生模型:
Image by author
模型蒸餾的目標是通過更少的計算量和訓練時間訓練一個較小的學生模型,復現教師模型的性能表現。
接下來,教師模型和學生模型同時對數據集進行預測,然后計算二者輸出的 Kullback-Leibler (KL) 散度(將于后文進行詳述)。該數值(KL 散度)用于計算梯度,指導模型各層參數應該如何調整,從而指導學生模型的參數更新:
Image by author
訓練完成后,學生模型達到與教師模型相當的準確率:
Image by author
02 創建一個用于模型蒸餾的示例項目
現在,我們對模型蒸餾的工作原理已經有了更清晰的理解,是時候通過一個簡單的示例來了解如何實現模型蒸餾了。我將使用 TensorFlow 和 MNIST 數據集訓練教師模型,然后應用模型蒸餾技術訓練一個較小的學生模型,使其在保持教師模型性能的同時降低資源需求。
2.1 使用 MNIST 數據集
確保已安裝 TensorFlow:
下一步加載 MNIST 數據集:
以下是從 MNIST 數據集中選取的前 9 個樣本圖像及其標簽:
需要對圖像數據進行歸一化處理,并擴展圖像數據的維度,為訓練做好準備:
2.2 定義教師模型
現在我們來定義教師模型 —— 一個具有多個網絡層的 CNN(卷積神經網絡):
請注意,學生模型的最后一層有 10 個神經元(對應 10 個數字類別),但未使用 softmax 激活函數。該層直接輸出原始 logits 值,這在模型蒸餾過程中非常重要,因為在模型蒸餾階段會應用 softmax 計算教師模型與學生模型之間的 Kullback-Leibler(KL)散度。
定義完教師神經網絡后,需通過 compile() 方法配置優化器(optimizer)、損失函數(loss function)和評估指標(metric for evaluation):
現在可以使用 fit() 方法訓練模型:
本次訓練進行了 5 個訓練周期:
2.3 定義學生模型
在教師模型訓練完成后,接下來定義學生模型。與教師模型相比,學生模型的結構更簡單、層數更少:
2.4 定義蒸餾損失函數
接下來定義蒸餾損失函數,該函數將利用教師模型的預測結果和學生模型的預測結果計算蒸餾損失(distillation loss)。該函數需完成以下操作:
- 使用教師模型對當前批次的輸入數據進行推理,生成軟標簽「硬標簽:[0, 0, 1](直接指定類別3)。軟標簽:[0.1, 0.2, 0.7](表示模型認為70%概率是類別3,但保留其他可能性)。」;
- 使用學生模型預測計算其軟標簽;
- 計算教師模型與學生模型軟標簽之間的 Kullback-Leibler(KL)散度;
- 返回蒸餾損失。
軟標簽(soft probabilities)指的是包含多種可能結果的概率分布,而非直接分配一個硬標簽。例如在垃圾郵件分類模型中,模型不會直接判定郵件"是垃圾郵件(1)"或"非垃圾郵件(0)",而是輸出類似"垃圾郵件概率 0.85,非垃圾郵件概率 0.15"的概率分布。 這意味著模型有 85% 的把握認為該郵件是垃圾郵件,但仍認為有 15% 的可能性不是,從而可以更好地進行決策和閾值調整。
軟標簽使用 softmax 函數進行計算,并由溫度參數(temperature)控制分布形態。在知識蒸餾過程中,教師模型提供的軟標簽能幫助學生模型學習到數據集各類別間的隱含關聯,從而獲得更優的泛化能力和性能表現。
以下是 distillation_loss() 函數的具體定義:
Kullback-Leibler(KL)散度 (又稱相對熵)是衡量兩個概率分布差異程度的數學方法。
2.5 使用知識蒸餾方法訓練學生模型
現在我們可以通過知識蒸餾訓練學生模型了。首先定義 train_step() 函數:
該函數只執行了一個訓練步驟:
- 計算學生模型的預測結果
- 利用教師模型的預測結果計算蒸餾損失
- 計算梯度并更新學生模型的權重
要對學生模型進行訓練,需要創建一個訓練循環(training loop)來遍歷數據集,每一步都會更新學生模型的權重,并在每個 epoch 結束時打印損失值以監測訓練進度:
2.6 評估學生模型
訓練完成后,你可以使用測試集(x_test 和 y_test)評估學生模型的表現:
不出所料,學生模型的準確率相當高:
2.7 使用教師模型和學生模型進行預測
現在可以使用教師模型和學生模型對 MNIST 測試集的數字進行預測,觀察兩者的預測能力:
前兩個樣本的預測結果如下:
若測試更多數字圖像樣本,你會發現學生模型的表現與教師模型同樣出色。
03 Summary
在本文,我們探討了模型蒸餾(Model Distillation)這一概念,這是一種讓結構更簡單、規模更小的學生模型復現或逼近結構更復雜的教師模型的性能的技術。我們利用 MNIST 數據集訓練教師模型,然后應用模型蒸餾技術訓練學生模型。最終,層數更少、結構更精簡的學生模型成功復現了教師模型的性能表現,同時還大大降低了計算資源的需求。
希望這篇文章能夠滿足各位讀者對模型蒸餾技術的好奇心,也希望本文提供的示例代碼可以直觀展現該技術的高效與實用。
About the author
Wei-Meng Lee
ACLP Certified Trainer | Blockchain, Smart Contract, Data Analytics, Machine Learning, Deep Learning, and all things tech (??http://calendar.learn2develop.net??).
END
本期互動內容 ??
?除了模型蒸餾,剪枝和量化也是常用的模型壓縮方法。在你們的項目中,更傾向于采用哪些方法? 歡迎在評論區分享~
本文經原作者授權,由 Baihai IDP 編譯。如需轉載譯文,請聯系獲取授權。
原文鏈接:
??https://ai.gopubby.com/understanding-model-distillation-991ec90019b6??
