大模型精準反哺小模型,知識蒸餾助力提高 AI 算法性能
01 知識蒸餾誕生的背景
來,深度神經網絡(DNN)在工業界和學術界都取得了巨大成功,尤其是在 計算機視覺任務 方面。深度學習的成功很大程度上歸功于其具有數十億參數的用于編碼數據的可擴展性架構,其訓練目標是在已有的訓練數據集上建模輸入和輸出之間的關系,其性能高度依賴于網絡的復雜程度及有標注訓練數據的數量和質量。
相比于計算機視覺領域的傳統算法,大多數基于 DNN 的模型都因為 過參數化 而具備強大的 泛化能力 ,這種泛化能力體現在對于某個問題輸入的所有數據上,模型能給出較好的預測結果,無論是訓練數據、測試數據,還是屬于該問題的未知數據。
在當前深度學習的背景下,算法工程師為了提升業務算法的預測效果,常常會有兩種方案:
使用過參數化的更復雜的網絡,這類網絡學習能力非常強,但需要大量的計算資源來訓練,并且推理速度較慢。
集成模型,將許多效果弱一些的模型集成起來,通常包括參數的集成和結果的集成。
這兩種方案能顯著提升現有算法的效果,但都提升了模型的規模,產生了較大的計算負擔,需要的計算和存儲資源很大。
在工作中,各種算法模型的最終目的都是要 服務于某個應用 。就像在買賣中我們需要控制收入和支出一樣。在工業應用中,除了要求模型要有好的預測以外, 計算資源的使用也要嚴格控制,不能只考慮結果不考慮效率。在輸入數據編碼量高的計算機視覺領域,計算資源更顯有限,控制算法的資源占用就更為重要。
通常來說,規模較大的模型預測效果更好,但訓練時間長、推理速度慢的問題使得模型難以實時部署。尤其是在視頻監控、自動駕駛汽車和高吞吐量云端環境等計算資源有限的設備上,響應速度顯然不夠用。規模較小的模型雖然推理速度較快,但是因為參數量不足,推理效果和泛化性能可能就沒那么好。如何權衡大規模模型和小規模模型一直是一個熱門話題,當前的解決方法大多是 根據部署環境的終端設備性能選擇合適規模的 DNN 模型。
如果我們希望有一個規模較小的模型,能在保持較快推理速度的前提下,達到和大模型相當或接近的效果該如何做到呢?
在機器學習中,我們常常假定輸入到輸出有一個潛在的映射函數關系,從頭學習一個新模型就是輸入數據和對應標簽中一個 近似 未知的映射函數。在輸入數據不變的前提下,從頭訓練一個小模型,從經驗上來看很難接近大模型的效果。為了提升小模型算法的性能,一般來說最有效的方式是標注更多的輸入數據,也就是提供更多的監督信息,這可以讓學習到的映射函數更魯棒,性能更好。舉兩個例子,在計算機視覺領域中,實例分割任務通過額外提供掩膜信息,可以提高目標包圍框檢測的效果;遷移學習任務通過提供在更大數據集上的預訓練模型,顯著提升新任務的預測效果。因此 提供更多的監督信息 ,可能是縮短小規模模型和大規模模型差距的關鍵。
按照之前的說法,想要獲取更多的監督信息意味著標注更多的訓練數據,這往往需要巨大的成本,那么有沒有一種低成本又高效的監督信息獲取方法呢?2006 年的文獻[1]中指出,可以讓新模型近似(approximate)原模型(模型即函數)。因為原模型的函數是已知的,新模型訓練時等于天然地增加了更多的監督信息,這顯然要更可行。
進一步思考,原模型帶來的監督信息可能蘊含著不同維度的知識,這些與眾不同的信息可能是新模型自己不能捕捉到的,在某種程度上來說,這對于新模型也是一種“跨域”的學習。
2015年Hinton在論文《Distilling the Knowledge in a Neural Network》[2] 中沿用近似的思想,率先提出“ 知識蒸餾 (Knowledge Distillation, KD)”的概念:可以先訓練出一個大而強的模型,然后將其包含的知識轉移給小的模型,就實現了“保持小模型較快推理速度的同時,達到和大模型相當或接近的效果”的目的。這其中先訓練的大模型可以稱之為教師模型,后訓練的小模型則被稱之為學生模型,整個訓練過程可以形象地比喻為“師生學習”。隨后幾年,涌現了大量的知識蒸餾與師生學習的工作,為工業界提供了更多新的解決思路。目前,KD 已廣泛應用于兩個不同的領域:模型壓縮和知識遷移[3]。
02 Knowledge Distillation
簡介
Knowledge Distillation 是一種基于“教師-學生網絡”思想的模型壓縮方法,由于簡單有效,在工業界被廣泛應用。其目的是將已經訓練好的大模型包含的知識——蒸餾(Distill),提取到另一個小的模型中去。那怎么讓大模型的知識,或者說泛化能力轉移到小模型身上去呢?KD 論文把大模型對樣本輸出的概率向量作為軟目標(soft targets)提供給小模型,讓小模型的輸出盡量去向這個軟目標靠(原來是往 one-hot 編碼上靠),去近似學習大模型的行為。
在傳統的硬標簽訓練過程中,所有負標簽都被統一對待,但這種方式把類別間的關系割裂開了。比如說識別手寫數字,同是標簽為“3”的圖片,可能有的比較像“8”,有的比較像“2”,硬標簽區分不出來這個信息,但是一個訓練良好的大模型可以給出。大模型 softmax 層的輸出,除了正例之外,負標簽也帶有大量的信息,比如某些負標簽對應的概率遠遠大于其他負標簽。近似學習這一行為使得每個樣本給學生網絡帶來的信息量大于傳統的訓練方式。
因此,作者在訓練學生網絡時修改了一下損失函數,讓小模型在擬合訓練數據的真值(ground truth)標簽的同時,也要擬合大模型輸出的概率分布。這個方法叫做知識 蒸餾訓練 (Knowledge Distillation Training, KD Training)。知識蒸餾過程所用的訓練樣本可以和訓練大模型用的訓練樣本一樣,或者另找一個獨立的 Transfer set。
方法詳解
具體來說,知識蒸餾使用的是 Teacher—Student 模型,其中 teacher 是“知識”的輸出者,student 是“知識”的接受者。知識蒸餾的過程分為 2 個階段:
- 教師模型訓練:訓練”Teacher 模型“, 簡稱為 Net-T,它的特點是模型相對復雜,也可以由多個分別訓練的模型集成而成。對“Teacher模型”不作任何關于模型架構、參數量、是否集成方面的限制,因為該模型不需要部署,唯一的要求就是,對于輸入 X, 其都能輸出 Y,其中 Y 經過 softmax 的映射,輸出值對應相應類別的概率值。
- 學生模型訓練:訓練“Student 模型”, 簡稱為 Net-S,它是參數量較小、模型結構相對簡單的單模型。同樣的,對于輸入 X,其都能輸出 Y,Y 經過 softmax 映射后同樣能輸出對應相應類別的概率值。
由于使用 softmax 的網絡的結果很容易走向極端,即某一類的置信度超高,其他類的置信度都很低,此時學生模型關注到的正類信息可能還是僅屬于某一類。除此之外,因為不同類別的負類信息也有相對的重要性,所有負類分數都差不多也不好,達不到知識蒸餾的目的。為了解決這個問題,引入溫度(Temperature)的概念,使用高溫將小概率值所攜帶的信息蒸餾出來。具體來說,在 logits 過 softmax 函數前除以溫度 T。
訓練時首先將教師模型學習到的知識蒸餾給小模型,具體來說對樣本 X,大模型的倒數第二層先除以一個溫度 T,然后通過 softmax 預測一個軟目標 Soft target,小模型也一樣,倒數第二層除以同樣的溫度 T,然后通過 softmax 預測一個結果,再把這個結果和軟目標的交叉熵作為訓練的 total loss 的一部分。然后再將小模型正常的輸出和真值標簽(hard target)的交叉熵作為訓練的 total loss 的另一部分。Total loss 把這兩個損失加權合起來作為訓練小模型的最終的 loss。
在小模型訓練好了要預測時,就不需要再有溫度 T 了,直接按照常規的 softmax 輸出就可以了。
03 FitNet
簡介
FitNet 論文在蒸餾時引入了中間層隱藏映射(intermediate-level hints)來指導學生模型的訓練。使用一個寬而淺的教師模型來訓練一個窄而深的學生模型。在進行 hint 引導時,提出使用一個層來匹配 hint 層和 guided 層的輸出 shape,這在后人的工作里面常被稱為 adaptation layer。
總的來說,相當于是在做知識蒸餾時,不僅用到了教師模型的 logit 輸出,還用到了教師模型的中間層特征圖作為監督信息。可以想到的是,直接讓小模型在輸出端模仿大模型,這個對于小模型來說太難了(模型越深越難訓,最后一層的監督信號要傳到前面去還是挺累的),不如在中間加一些監督信號,使得模型在訓練時可以從逐層接受學習更難的映射函數,而不是直接學習最難的映射函數;除此之外,hint 引導加速了學生模型的收斂,在一個非凸問題上找到更好的局部最小值,使得學生網絡能更深的同時,還能訓練得更快。這感覺就好像是,我們的目的是讓學生做高考題,那么就先把初中的題目給他教會了(先讓小模型用前半個模型學會提取圖像底層特征),然后再回到本來的目的、去學高考題(用 KD 調整小模型的全部參數)。
這篇文章是提出蒸餾中間特征圖的始祖,提出的算法很簡單,但思路具有開創性。
方法詳解
FitNets 的具體做法是:
- 確定教師網絡,并訓練成熟,將教師網絡的中間特征層 hint 提取出來。
- 設定學生網絡,該網絡一般較教師網絡更窄、更深。訓練學生網絡使得學生網絡的中間特征層與教師模型的 hint 相匹配。由于學生網絡的中間特征層和與教師 hint 尺寸不同,因此需要在學生網絡中間特征層后添加回歸器用于特征升維,以匹配 hint 層尺寸。其中匹配教師網絡的 hint 層與回歸器轉化后的學生網絡的中間特征層的損失函數為均方差損失函數。
實際訓練的時候往往和上一節的 KD Training 聯合使用,用兩階段法訓練:先用 hint training 去 pretrain 小模型前半部分的參數,再用 KD Training 去訓練全體參數。由于蒸餾過程中使用了更多的監督信息, 基于中間特征圖的蒸餾方法比基于結果 logits 的蒸餾方法效果要好 ,但是訓練時間更久。
04 總結
知識蒸餾對于將知識從集成或從高度正則化的大型模型轉移到較小的模型中非常有效。即使在用于訓練蒸餾模型的遷移數據集中缺少任何一個或多個類的數據時,蒸餾的效果也非常好。在經典之作 KD 和 FitNet 提出之后,各種各樣的蒸餾方法如雨后春筍般涌現。未來我們也希望能在模型壓縮和知識遷移領域做出更進一步的探索。
作者簡介
馬佳良,網易易盾高級計算機視覺算法工程師,主要負責計算機視覺算法在內容安全領域的研發、優化和創新。