「知識蒸餾+持續學習」最新綜述!哈工大、中科院出品:全新分類體系,十大數據集全面實驗
知識蒸餾(Knowledge Distillation, KD)已逐漸成為持續學習(Continual Learning, CL)應對災難性遺忘的常見方法。
然而,盡管KD在減輕遺忘方面取得了一定成果,關于KD在持續學習中的應用及其有效性仍然缺乏深入的探索。
圖1 知識蒸餾在持續學習中的使用
目前,大多數現有的持續學習綜述主要從不同方法的分類角度出發,聚焦于圖像分類領域或其他應用領域,很少有綜述文章專門探討如何通過具體技術(如知識蒸餾)來緩解持續學習中的遺忘問題。
現有的研究大多關注于持續學習方法的廣泛分類(如回放方法、正則化方法、參數隔離方法等),以及任務增量學習、類別增量學習、領域增量學習等不同場景的定義。
盡管這些研究為持續學習領域提供了寶貴的見解,但對于如何將知識蒸餾與持續學習結合并分析其效果,仍然缺乏系統性評估。
最近,哈爾濱工業大學和中科院自動化所的研究人員在IEEE Transactions on Neural Networks and Learning Systems(TNNLS)上發表了一篇綜述論文,聚焦于知識蒸餾在持續學習中的應用。
論文鏈接:https://ieeexplore.ieee.org/document/10721446
主要貢獻為:
- 綜合調查:首次系統地綜述了基于知識蒸餾的持續學習方法,主要集中在圖像分類任務中。研究人員分析了知識蒸餾在持續學習中的應用,提供了詳細的分類,闡述了其在持續學習中的作用與應用場景。
- 新的分類法:提出了一個新的分類體系,將知識蒸餾在持續學習中的應用分為三個主要范式:正則化的知識蒸餾、知識蒸餾與數據回放結合、以及知識蒸餾與特征回放結合。同時,基于蒸餾過程中使用的知識來源,將這些方法分為三個層次:logits級別、特征級別和數據級別,并從蒸餾損失的角度分析了其如何強化記憶。
- 實驗驗證:在CIFAR-100、TinyImageNet和ImageNet-100等數據集上,針對十種基于知識蒸餾的持續學習方法進行了廣泛的實驗,系統地分析了知識蒸餾在持續學習中的作用,驗證了其在減緩遺忘方面的有效性。
- 分類偏差與改進:進一步證實,分類偏差可能會削弱知識蒸餾的效果,而采用separated softmax損失函數結合數據回放時,能夠顯著增強知識蒸餾在減緩遺忘方面的效果。
基于知識蒸餾的持續學習范式
圖2 知識蒸餾在持續學習中的使用范式
正則化的知識蒸餾
正則化的知識蒸餾范式將知識蒸餾作為防止遺忘的核心機制,通過約束模型參數變化來保持舊任務的知識。這些方法的基本思想是通過在新任務學習時,確保模型的輸出盡可能與舊任務模型的輸出一致,從而避免遺忘。
例如,LwF方法通過蒸餾新任務數據在舊模型上的響應,確保新模型在學習新任務時仍能保留對舊任務的記憶[1]。這類方法的理念相對簡單明了,在減輕遺忘方面的表現往往較弱,通常會導致較低的性能。
知識蒸餾與數據回放結合
知識蒸餾經常與數據回放技術相結合,以從數據和模型兩個方面增強記憶保持能力。數據回放方法需要一個額外的緩存區來存儲來自先前任務的樣本,以近似其分布,并在持續學習過程中不斷回放這些樣本,以實現持久的記憶保持。
將知識蒸餾與數據回放結合,進一步增強了模型的記憶保持能力。iCaRL是第一個將知識蒸餾與數據回放相結合的方法[2]。此后,許多結合知識蒸餾和數據回放的方法將數據回放作為應對遺忘的基本技術,并探索各種蒸餾技術以進一步增強舊任務的記憶保持能力。
除了直接使用額外的內存存儲來自舊任務的回放數據外,一些方法還通過生成模型[5][6]或模型反演[7][8][9]技術生成回放數據。
這些方法通常將知識蒸餾應用于生成的數據,以防止生成模型在持續學習過程中遺忘,同時也在logits或特征上使用基本的知識蒸餾技術來減緩遺忘。對于這些方法來說,除了知識蒸餾是減緩遺忘的有效手段外,生成數據的質量也在決定整體效果方面起著至關重要的作用。
由于回放數據與新任務數據之間存在嚴重的數據不平衡,這容易導致分類偏差,一些方法在將知識蒸餾作為記憶保持的基本機制的同時,著重解決分類偏差問題。
例如,BiC顯式地通過在平衡的驗證數據集上訓練類別校正參數來解決分類偏差問題[10]。除了分類偏差問題外,一些其他方法將知識蒸餾與數據回放的結合作為基本的記憶保持手段,并更加關注其他問題,如回放數據的選擇[13][14]以及結合基于架構的方法[15][16][17]來保持記憶。
知識蒸餾與特征回放結合
除了將知識蒸餾與數據回放結合,許多方法還將知識蒸餾與特征回放結合,旨在實現無需示例的持續學習。這一范式中的大多數方法通過在特征級別的蒸餾中使用實例特征對齊,以保持特征網絡的記憶能力,并采用各種特征生成方法來生成回放特征,從而確保分類器的記憶得到保持。
例如,GFR方法通過訓練生成模型來存儲舊任務的特征,該生成模型在持續學習過程中生成回放特征[18]。PASS方法將類別原型定義為特征空間數據的均值,并在新類別學習期間引入高斯噪聲進行數據增強,從而避免分類偏向新數據[19]。
與「知識蒸餾與數據回放結合」范式中的方法相比,這一范式不需要大量額外的內存來存儲舊任務的原始樣本。相反,它只需少量內存來存儲每個類別的特征信息。此外,特征回放有助于減少由于回放數據和新任務數據之間的不平衡所引起的分類偏差問題。
知識來源與蒸餾損失
圖3 按知識來源分類的基于知識蒸餾的持續學習方法
研究人員根據知識來源將基于知識蒸餾的持續學習方法分為三類:logits級別、特征級別和數據級別。
Logits級別蒸餾主要涉及學生模型通過模仿教師模型的最終輸出logits來獲取知識。這些輸出通常包括兩種類型:通過歸一化函數(如softmax)得到的分類概率,以及原始的、未經歸一化的logits。
因此,研究人員將logits級別的KD方法分為兩類:概率匹配和logits匹配。概率匹配較為常見,學生模型旨在通過使用KL散度或交叉熵等損失函數,將教師模型的輸出概率分布與自己的輸出概率分布對齊。
相比之下,logits匹配旨在同步教師和學生模型的pre-softmax logit值,通常采用L1或L2范數等損失函數。logits匹配對蒸餾過程施加了比概率匹配更嚴格的約束。
特征級別蒸餾旨在傳遞網絡特征提取階段生成的內部表示知識。這類方法可以根據特征在網絡中的位置和特征的性質分為三個子類:實例特征對齊、隱層特征對齊和關系對齊。
實例特征對齊主要針對從輸入樣本中提取的特征,這些特征通常被轉換為一維向量。隱層特征對齊則關注特征提取器中間層特征的蒸餾,這些特征保留了與網絡結構相關的空間信息。關系對齊則專注于蒸餾多個實例或原型特征之間在特征空間中的局部或全局關系動態。
數據級別蒸餾可以分為兩種類型:顯式數據對齊和隱式數據對齊。顯式數據對齊涉及通過生成模型產生的合成數據進行蒸餾。與此不同,隱式數據對齊則專注于蒸餾數據中的潛在信息,例如注意力圖或潛在編碼。
圖4顯示了一些logits級別和特征級別蒸餾的示意圖。表1展示了不同范式的持續學習方法使用的知識蒸餾級別以及相應使用的蒸餾損失。
圖4 logits級別與特征級別蒸餾示意圖
表1 基于知識蒸餾的持續學習方法歸納分類
實驗
研究人員選擇了三個在持續學習領域廣泛使用的圖像分類數據集:CIFAR-100、TinyImageNet和ImageNet-100,涵蓋了從32×32、64×64到224×224像素的不同圖像分辨率,實驗聚焦于類別增量學習(CIL)。
研究人員采用了兩種主要策略來模擬數據增量場景:第一種方法將數據集均勻分成多個任務,每個任務包含相等數量的類別,進行持續學習;第二種方法先對一部分類別進行初步的基礎訓練,然后使用剩余類別進行持續學習。
為了清晰描述這些場景,研究人員采用了[22]中的符號表示,選擇了十個基于知識蒸餾的持續學習方法進行實驗:LwF [1]、LwM [3]、IL2A [20]、PASS [19]、PRAKA [21]、iCaRL [2]、EEIL [4]、BiC [10]、LUCIR [11]和SS-IL [12]。
針對數據集的實驗
表2 針對不同數據集的實驗結果
實驗結果如表2所示。針對所有數據集,在沒有基礎訓練的10任務場景中,BiC方法在所有數據集上表現最佳。在有基礎訓練的11任務場景中,PRAKA在所有數據集上表現突出。在有無基礎訓練的兩種場景中,「知識蒸餾與數據回放結合」范式的方法普遍表現較好。
在沒有基礎訓練的場景中,「知識蒸餾與特征回放結合」范式的方法略遜于數據回放范式。然而,在有基礎訓練的場景中,特征回放方法的表現顯著提升,PRAKA在所有數據集上超過了數據回放范式的方法。
相比之下,「正則化的知識蒸餾」范式方法表現較差,且LwF和LwM在有基礎訓練的場景中表現低于沒有基礎訓練的情況,其他方法通常在有基礎訓練的場景中表現更好。
針對知識蒸餾效果的實驗
表3 針對知識蒸餾效果的實驗結果
本實驗通過去除知識蒸餾損失函數,探討了知識蒸餾在持續學習抗遺忘中起的作用,實驗結果如表3所示。所有方法中,除了LwM采用了兩種蒸餾損失外,大多數方法都使用了單一的蒸餾損失。對于LwM,僅去除其注意力圖蒸餾損失,保留了logits級蒸餾。
知識蒸餾在「正則化的知識蒸餾」以及「知識蒸餾與特征回放結合」范式的方法中起到了關鍵作用。在有無基礎訓練的場景中,去除知識蒸餾后,性能明顯下降。
然而,在「知識蒸餾與數據回放結合」范式下的方法中,情況有所不同。結果顯示,在有基礎訓練的場景中,知識蒸餾顯著有助于減緩遺忘,一旦去除蒸餾,所有方法的性能均有所下降。
在沒有基礎訓練的場景中,EEIL、BiC和SS-IL在去除KD后表現下降。相反,iCaRL和LUCIR的性能有所提升,iCaRL的提升尤為明顯,LUCIR的提升較小。
針對蒸餾損失的實驗
表4 針對蒸餾損失的實驗結果
為了評估不同蒸餾損失在減緩遺忘方面的有效性,研究人員進行了獨立的知識蒸餾損失評估,未使用任何其他防止遺忘的技術。
研究人員評估了交叉熵、KL散度、logits級的L2距離損失,以及基于L2距離和余弦相似度的實例特征對齊損失,實驗結果如表4所示。
在持續學習過程中,分類頭的訓練采用了LwF中的方式,即只訓練當前任務的分類頭,而之前任務的分類頭僅參與蒸餾,因為如果沒有來自舊任務的數據使用全局分類損失,會導致嚴重的分類偏差問題,并顯著降低知識蒸餾的效果。
結果表明不同知識蒸餾損失均有減緩遺忘的能力。其中,logits級的知識蒸餾損失在減緩遺忘方面表現明顯優于特征級的知識蒸餾損失。在所有logits級知識蒸餾損失中,L2距離損失具有更強的約束能力,較KL散度表現更好,優于交叉熵蒸餾損失的抗遺忘效果。
對于特征級的知識蒸餾損失,包含更多語義信息的余弦相似度損失,在減緩遺忘方面優于L2距離損失。
針對知識蒸餾與數據回放的實驗
表5 針對知識蒸餾與數據回放的實驗結果
為了進一步了解知識蒸餾在與數據回放結合時的作用,并探索不同知識蒸餾損失的效果,研究人員將幾種知識蒸餾損失與基本的回放范式進行比較,數據回放使用herding方式來緩存回放數據,每個類別保存20個樣本。
實驗結果(表5 -a)表明將知識蒸餾與數據回放結合時,logits級的知識蒸餾損失始終會導致性能下降,這一負面影響在沒有基礎訓練的情況下尤為明顯,logits級知識蒸餾會顯著降低性能。
在基礎訓練的情況下,特征級知識蒸餾的正面效果稍微更明顯,而余弦相似度損失在保持已學習特征方面表現優越。然而,在沒有基礎訓練的情況下,余弦相似度損失在保持記憶方面的效果不如L2損失。
研究人員假設這種現象可能是由于分類頭引入的分類偏差所致。為了驗證這一假設,采用SS-IL的方法中使用Separated softmax損失來學習分類頭,即使用回放數據共同訓練所有舊任務的分類頭,而新任務數據則專門用于訓練新任務的分類頭。實驗結果(表5 -b)表明分類偏差確實會影響KD的效果。
令人驚訝的是,即使沒有使用KD,使用Separated softmax的數據回放也比全局分類的回放表現更好。
未來展望
論文從三個不同的視角探討了基于知識蒸餾的持續學習的未來發展趨勢。
高質量知識的知識蒸餾:盡管知識蒸餾在減緩持續學習中的災難性遺忘方面已經展現出潛力,但仍有較大的提升空間。有效的知識傳遞依賴于蒸餾知識的質量。高質量的知識傳遞對于提升持續學習中的知識蒸餾效果至關重要。
隨著對知識質量的要求越來越高,如何更好地提取和傳遞高質量知識,將是未來持續學習研究中的一個重要方向。
針對特定任務的知識蒸餾:持續學習的研究已從最初專注于分類任務,擴展到包括其他多種任務,例如計算機視覺中的目標檢測、語義分割,以及自然語言處理中的語言學習、機器翻譯、意圖識別和命名實體識別等。
這表明,知識蒸餾不僅能夠應用于傳統的分類任務,還需要針對具體任務進行定制化設計,以提高在不同應用場景中的表現。
更好的教師模型:近年來,基于預訓練模型(PTM)和大型語言模型(LLM)的持續學習受到了越來越多的關注。知識蒸餾作為一種自然適用于減少遺忘的技術,對于PTM和LLM的持續學習尤為重要。
這是因為知識蒸餾遵循教師-學生框架,而PTM和LLM已經具備了豐富的知識,可以作為「具有豐富的經驗教師」開始持續學習,從而更有效地指導學生模型的學習。未來,如何通過更強大的教師模型來優化知識蒸餾的效果,將是持續學習中值得深入研究的方向。
參考資料:
1.Z. Li and D. Hoiem, “Learning without forgetting,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 40, no. 12, pp. 2935–2947, 2017.
2.S.-A. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert, “icarl: Incremental classifier and representation learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 2001–2010, 2017.
3.P. Dhar, R. V. Singh, K.-C. Peng, Z. Wu, and R. Chellappa, “Learning without memorizing,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 5138–5146, 2019.
4.F. M. Castro, M. J. Mar??n-Jiménez, N. Guil, C. Schmid, and K. Alahari, “End-to-end incremental learning,” in Eur. Conf. Comput. Vis., pp. 233–248, 2018.
5.C. Wu, L. Herranz, X. Liu, J. Van De Weijer, B. Raducanu, et al., “Memory replay gans: Learning to generate new categories without forgetting,” in Adv. Neural Inform. Process. Syst., vol. 31, pp. 5966–5976, 2018.
6.W. Hu, Z. Lin, B. Liu, C. Tao, Z. T. Tao, D. Zhao, J. Ma, and R. Yan, “Overcoming catastrophic forgetting for continual learning via model adaptation,” in Int. Conf. Learn. Represent., 2019.
7.J. Smith, Y.-C. Hsu, J. Balloch, Y. Shen, H. Jin, and Z. Kira, “Always be dreaming: A new approach for data-free class-incremental learning,”in Int. Conf. Comput. Vis., pp. 9374–9384, 2021.
8.Q. Gao, C. Zhao, B. Ghanem, and J. Zhang, “R-dfcil: Relation-guided representation learning for data-free class incremental learning,” in Eur. Conf. Comput. Vis., pp. 423–439, Springer, 2022.
9.M. PourKeshavarzi, G. Zhao, and M. Sabokrou, “Looking back on learned experiences for class/task incremental learning,” in Int. Conf. Learn. Represent., 2021.
10.Y. Wu, Y. Chen, L. Wang, Y. Ye, Z. Liu, Y. Guo, and Y. Fu, “Large scale incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 374–382, 2019.
11.S. Hou, X. Pan, C. C. Loy, Z. Wang, and D. Lin, “Learning a unified classifier incrementally via rebalancing,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 831–839, 2019.
12.H. Ahn, J. Kwak, S. Lim, H. Bang, H. Kim, and T. Moon, “Ss-il: Separated softmax for incremental learning,” in Int. Conf. Comput. Vis., pp. 844–853, 2021.
13.Y. Liu, Y. Su, A.-A. Liu, B. Schiele, and Q. Sun, “Mnemonics training: Multi-class incremental learning without forgetting,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 12245–12254, 2020.
14.R. Tiwari, K. Killamsetty, R. Iyer, and P. Shenoy, “Gcr: Gradient coreset based replay buffer selection for continual learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 99–108, 2022.
15.J. Rajasegaran, M. Hayat, S. H. Khan, F. S. Khan, and L. Shao,“Random path selection for continual learning,” in Adv. Neural Inform. Process. Syst., vol. 32, pp. 12648–12658, 2019.
16.A. Douillard, A. Ramé, G. Couairon, and M. Cord, “Dytox: Transformers for continual learning with dynamic token expansion,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 9285–9295, 2022.
17.F.-Y. Wang, D.-W. Zhou, H.-J. Ye, and D.-C. Zhan, “Foster: Feature boosting and compression for class-incremental learning,” in Eur. Conf. Comput. Vis., pp. 398–414, Springer, 2022.
18.X. Liu, C. Wu, M. Menta, L. Herranz, B. Raducanu, A. D. Bagdanov, S. Jui, and J. v. de Weijer, “Generative feature replay for class-incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 226–227, 2020.
19.F. Zhu, X.-Y. Zhang, C. Wang, F. Yin, and C.-L. Liu, “Prototype augmentation and self-supervision for incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 5871–5880, 2021.
20.F. Zhu, Z. Cheng, X.-Y. Zhang, and C.-l. Liu, “Class-incremental learning via dual augmentation,” in Adv. Neural Inform. Process. Syst., vol. 34, pp. 14306–14318, 2021.
21.W. Shi and M. Ye, “Prototype reminiscence and augmented asymmetric knowledge aggregation for non-exemplar class-incremental learning,”in Int. Conf. Comput. Vis., pp. 1772–1781, 2023.
22.M. Masana, X. Liu, B. Twardowski, M. Menta, A. D. Bagdanov, and J. Van De Weijer, “Class-incremental learning: survey and performance evaluation on image classification,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 45, no. 5, pp. 5513–5533, 2022.