輕量級(jí)持續(xù)學(xué)習(xí): 0.6%額外參數(shù)把舊模型重編程到新任務(wù)
持續(xù)學(xué)習(xí)的目的是模仿人類在連續(xù)任務(wù)中持續(xù)積累知識(shí)的能力,其主要挑戰(zhàn)是在持續(xù)學(xué)習(xí)新任務(wù)后如何保持對(duì)以前所學(xué)任務(wù)的表現(xiàn),即避免災(zāi)難性遺忘(catastrophic forgetting)。持續(xù)學(xué)習(xí)和多任務(wù)學(xué)習(xí)(multi-task learning)的區(qū)別在于:后者在同一時(shí)間可以得到所有任務(wù),模型可以同時(shí)學(xué)習(xí)所有任務(wù);而在持續(xù)學(xué)習(xí)中任務(wù) 一個(gè)一個(gè)出現(xiàn),模型在某一時(shí)刻只能學(xué)習(xí)一個(gè)任務(wù)的知識(shí),并且在學(xué)習(xí)新知識(shí)的過程中避免遺忘舊知識(shí)。
南加州大學(xué)聯(lián)合 Google Research 提出了一種解決持續(xù)學(xué)習(xí)(continual learning)的新方法通道式輕量級(jí)重編碼(Channel-wise Lightweight Reprogramming [CLR]):通過在固定任務(wù)不變的 backbone 中添加可訓(xùn)練的輕量級(jí)模塊,對(duì)每層通道的特征圖進(jìn)行重編程,使得重編程過的特征圖適用于新任務(wù)。這個(gè)可訓(xùn)練的輕量級(jí)模塊僅僅占整個(gè)backbone的0.6%,每個(gè)新任務(wù)都可以有自己的輕量級(jí)模塊,理論上可以持續(xù)學(xué)習(xí)無(wú)窮多新任務(wù)而不會(huì)出現(xiàn)災(zāi)難性遺忘。文已發(fā)表在 ICCV 2023。
- 論文地址: https://arxiv.org/pdf/2307.11386.pdf
- 項(xiàng)目地址: https://github.com/gyhandy/Channel-wise-Lightweight-Reprogramming
- 數(shù)據(jù)集地址: http://ilab.usc.edu/andy/skill102
通常解決持續(xù)學(xué)習(xí)的方法主要分為三大類:基于正則化的方法、動(dòng)態(tài)網(wǎng)絡(luò)方法和重放方法。
- 基于正則化的方法是模型在學(xué)習(xí)新任務(wù)的過程中對(duì)參數(shù)更新添加限制,在學(xué)習(xí)新知識(shí)的同時(shí)鞏固舊知識(shí)。
- 動(dòng)態(tài)網(wǎng)絡(luò)方法是在學(xué)習(xí)新任務(wù)的時(shí)候添加特定任務(wù)參數(shù)并對(duì)舊任務(wù)的權(quán)重進(jìn)行限制。
- 重放方法假設(shè)在學(xué)習(xí)新任務(wù)的時(shí)候可以獲取舊任務(wù)的部分?jǐn)?shù)據(jù),并與新任務(wù)一起訓(xùn)練。
本文提出的CLR方法是一種動(dòng)態(tài)網(wǎng)絡(luò)方法。下圖表示了整個(gè)過程的 pipeline:研究者使用與任務(wù)無(wú)關(guān)的不可變部分作為共享的特定任務(wù)參數(shù),并添加特定任務(wù)參數(shù)對(duì)通道特征進(jìn)行重編碼。與此同時(shí)為了盡可能地減少訓(xùn)練每個(gè)任務(wù)的重編碼參數(shù),研究者只需要調(diào)整模型中內(nèi)核的大小,并學(xué)習(xí)從 backbone 到特定任務(wù)知識(shí)的通道線性映射來(lái)實(shí)現(xiàn)重編碼。在持續(xù)學(xué)習(xí)中,對(duì)于每一個(gè)新任務(wù)都可以訓(xùn)練得到一個(gè)輕量級(jí)模型;這種輕量級(jí)的模型需要訓(xùn)練的參數(shù)很少,即使任務(wù)很多,總共需要訓(xùn)練的參數(shù)相對(duì)于大模型來(lái)說也很小,并且每一個(gè)輕量級(jí)模型都可以達(dá)到很好的效果。
研究動(dòng)機(jī)
持續(xù)學(xué)習(xí)關(guān)注于從數(shù)據(jù)流中學(xué)習(xí)的問題,即通過特定的順序?qū)W習(xí)新任務(wù),不斷擴(kuò)展其已獲得的知識(shí),同時(shí)避免遺忘以前的任務(wù),因此如何避免災(zāi)難性遺忘是持續(xù)學(xué)習(xí)研究的主要問題。研究者從以下三個(gè)方面考慮:
- 重用而不是重學(xué):對(duì)抗重編碼(Adversarial Reprogramming [1])是一種通過擾動(dòng)輸入空間,在不重新學(xué)習(xí)網(wǎng)絡(luò)參數(shù)的情況下,"重編碼" 一個(gè)已經(jīng)訓(xùn)練并凍結(jié)的網(wǎng)絡(luò)來(lái)解決新任務(wù)的方法。研究者借用了 “重編碼” 的思想,在原始模型的參數(shù)空間而不是輸入空間進(jìn)行了更輕量級(jí)但也更強(qiáng)大的重編程。
- 通道式轉(zhuǎn)換可以連接兩個(gè)不同的核:GhostNet [2] 的作者發(fā)現(xiàn)傳統(tǒng)網(wǎng)絡(luò)在訓(xùn)練后會(huì)得到一些相似的特征圖,因此他們提出了一種新型網(wǎng)絡(luò)架構(gòu) GhostNet:通過對(duì)現(xiàn)有特征圖使用相對(duì)廉價(jià)的操作(比如線性變化)生成更多的特征圖,以此來(lái)減小內(nèi)存。受此啟發(fā),本文方法同樣使用線性變換生成特征圖來(lái)增強(qiáng)網(wǎng)絡(luò),這樣就能以相對(duì)低廉的成本為各個(gè)新任務(wù)量身定制。
- 輕量級(jí)參數(shù)可以改變模型分布:BPN [3] 通過在全連接層中增加了有益的擾動(dòng)偏差,使網(wǎng)絡(luò)參數(shù)分布從一個(gè)任務(wù)轉(zhuǎn)移到另一個(gè)任務(wù)。然而 BPN 只能處理全連接層,每個(gè)神經(jīng)元只有一個(gè)標(biāo)量偏置,因此改變網(wǎng)絡(luò)的能力有限。相反研究者為卷積神經(jīng)網(wǎng)絡(luò)(CNN)設(shè)計(jì)了更強(qiáng)大的模式(在卷積核中增加 “重編碼” 參數(shù)),從而在每項(xiàng)新任務(wù)中實(shí)現(xiàn)更好的性能。
方法敘述
通道式輕量級(jí)重編碼首先用一個(gè)固定的 backbone 作為一個(gè)任務(wù)共享的結(jié)構(gòu),這可以是一個(gè)在相對(duì)多樣性的數(shù)據(jù)集(ImageNet-1k, Pascal VOC)上進(jìn)行監(jiān)督學(xué)習(xí)的預(yù)訓(xùn)練模型,也可以是在無(wú)語(yǔ)義標(biāo)簽的代理任務(wù)上學(xué)習(xí)的自監(jiān)督學(xué)習(xí)模型(DINO,SwAV)。不同于其他的持續(xù)學(xué)習(xí)方法(比如 SUPSUP 使用一個(gè)隨機(jī)初始化的固定結(jié)構(gòu),CCLL 和 EFTs 使用第一個(gè)任務(wù)學(xué)習(xí)后的模型作為 backbone),CLR 使用的預(yù)訓(xùn)練模型可以提供多種視覺特征,但這些視覺特征在其他任務(wù)上需要 CLR 層進(jìn)行重編碼。具體來(lái)說,研究者利用通道式線性變化(channel-wise linear transformation)對(duì)原有卷積核產(chǎn)生的特征圖像進(jìn)行重編碼。
圖中展示了 CLR 的結(jié)構(gòu)。CLR 適用于任何卷積神經(jīng)網(wǎng)絡(luò),常見的卷積神經(jīng)網(wǎng)絡(luò)由 Conv 塊(Residual 塊)組成,包括卷積層、歸一化層和激活層。
研究者首先把預(yù)訓(xùn)練的 backbone 固定,然后在每個(gè)固定卷積塊中的卷積層后面加入通道式輕量級(jí)重編程層 (CLR 層)來(lái)對(duì)固定卷積核后的特征圖進(jìn)行通道式線性變化。
給定一張圖片 X,對(duì)于每個(gè)卷積核 ,可以得到通過卷積核的特征圖 X’,其中每個(gè)通道的特征可以表示為
;之后用 2D 卷積核來(lái)對(duì) X’的每個(gè)通道
進(jìn)行線性變化,假設(shè)每個(gè)卷積核
對(duì)應(yīng)的線性變化的卷積核為
,那么可以得到重編碼后的特征圖
。研究者將 CLR 卷積核的初始化為同一變化核(即對(duì)于的 2D 卷積核,只有中間參數(shù)為 1,其余都為 0),因?yàn)檫@樣可以使得最開始訓(xùn)練時(shí)原有固定 backbone 產(chǎn)生的特征和加入 CLR layer 后模型產(chǎn)生的特征相同。同時(shí)為了節(jié)約參數(shù)并防止過擬合,研究者并不會(huì)在的卷積核后面加入 CLR 層,CLR 層只會(huì)作用在的卷積核后。對(duì)于經(jīng)過 CLR 作用的 ResNet50 來(lái)說,增加的可訓(xùn)練參數(shù)相比于固定的 ResNet50 backbone 只占 0.59%。
對(duì)于持續(xù)學(xué)習(xí),加入 CLR 的模型(可訓(xùn)練的 CLR 參數(shù)和不可訓(xùn)練的 backbone)可以依次學(xué)習(xí)每個(gè)任務(wù)。在測(cè)試的時(shí)候,研究者假設(shè)有一個(gè) task oracle 可以告訴模型測(cè)試圖片屬于哪個(gè)任務(wù),之后固定的 backbone 和相對(duì)應(yīng)的任務(wù)專有 CLR 參數(shù)可以進(jìn)行最終預(yù)測(cè)。由于 CLR 具有絕對(duì)參數(shù)隔離的性質(zhì)(每個(gè)任務(wù)對(duì)應(yīng)的 CLR 層參數(shù)都不一樣并且共享的 backbone 不會(huì)變化),因此 CLR 不會(huì)受到任務(wù)數(shù)量的影響。
實(shí)驗(yàn)結(jié)果
數(shù)據(jù)集:研究者使用圖像分類作為主要任務(wù),實(shí)驗(yàn)室收集了 53 個(gè)圖像分類數(shù)據(jù)集,有大約 180 萬(wàn)張圖片和 1584 個(gè)種類。這 53 個(gè)數(shù)據(jù)集包含了 5 個(gè)不同的分類目標(biāo):物體識(shí)別,風(fēng)格分類,場(chǎng)景分類,計(jì)數(shù)和醫(yī)療診斷。
基線:研究者選擇了 13 種基線,大概可以分成 3 個(gè)種類
- 動(dòng)態(tài)網(wǎng)絡(luò):PSP,SupSup,CCLL,Confit,EFTs
- 正則化:EWC,online-EWC,SI,LwF
- 重放:ER,DERPP
還有一些不屬于持續(xù)學(xué)習(xí)的基線,比如 SGD 和 SGD-LL。SGD 學(xué)習(xí)每個(gè)任務(wù)時(shí)對(duì)整個(gè)網(wǎng)絡(luò)進(jìn)行微調(diào);SGD-LL 是一個(gè)變體,它對(duì)所有任務(wù)都使用一個(gè)固定的 backbone 和一個(gè)可學(xué)習(xí)的共享層,其長(zhǎng)度等于所有任務(wù)最大的種類數(shù)量。
實(shí)驗(yàn)一:第一個(gè)任務(wù)的準(zhǔn)確率
為了評(píng)估所有方法在克服災(zāi)難性遺忘的能力,研究者跟蹤了學(xué)習(xí)新任務(wù)后每個(gè)任務(wù)的準(zhǔn)確性。如果某個(gè)方法存在災(zāi)難性遺忘,那么在學(xué)習(xí)新任務(wù)后,同一任務(wù)的準(zhǔn)確率就會(huì)很快下降。一個(gè)好的持續(xù)學(xué)習(xí)算法可以在學(xué)習(xí)新任務(wù)后保持原有的表現(xiàn),這就意味著舊任務(wù)應(yīng)受到新任務(wù)的影響最小。下圖展示了本文方法從學(xué)完第 1 到第 53 個(gè)任務(wù)后第 1 個(gè)任務(wù)的準(zhǔn)確率。總體而言,本文方法可以保持最高的準(zhǔn)確率。更重要的是它很好地避免了災(zāi)難性遺忘并保持和原始訓(xùn)練方式得到的相同準(zhǔn)確率無(wú)論持續(xù)學(xué)習(xí)多少個(gè)任務(wù)。
實(shí)驗(yàn)二:學(xué)習(xí)所有任務(wù)后的平均準(zhǔn)確率
下圖所有方法在學(xué)完全部任務(wù)后的平均準(zhǔn)確率。平均準(zhǔn)確率反映了持續(xù)學(xué)習(xí)方法的整體表現(xiàn)。由于每個(gè)任務(wù)的難易程度不同,當(dāng)增加一項(xiàng)新任務(wù)時(shí),所有任務(wù)的平均精確度可能會(huì)上升或下降,這取決于增加的任務(wù)是簡(jiǎn)單還是困難。
分析一:參數(shù)和計(jì)算成本
對(duì)于持續(xù)學(xué)習(xí),雖然獲得更高的平均準(zhǔn)確率非常重要,但是一個(gè)好的算法也希望可以最大限度地減少對(duì)額外網(wǎng)絡(luò)參數(shù)的要求和計(jì)算成本。"添加一項(xiàng)新任務(wù)的額外參數(shù)" 表示與原始 backbone 參數(shù)量的百分比。本文以 SGD 的計(jì)算成本為單位,其他方法的計(jì)算成本按 SGD 的成本進(jìn)行歸一化處理。
分析二:不同 backbone 的影響
本文方法通過在相對(duì)多樣化的數(shù)據(jù)集上使用監(jiān)督學(xué)習(xí)或自監(jiān)督學(xué)習(xí)的方法來(lái)訓(xùn)練得到預(yù)訓(xùn)練模型,從而作為與任務(wù)無(wú)關(guān)的不變參數(shù)。為了探究不同預(yù)訓(xùn)練方法的影響,本文選擇了四種不同的、與任務(wù)無(wú)關(guān)的、使用不同數(shù)據(jù)集和任務(wù)訓(xùn)練出來(lái)的預(yù)訓(xùn)練模型。對(duì)于監(jiān)督學(xué)習(xí),研究者使用了在 ImageNet-1k 和 Pascal-VOC 在圖像分類上的預(yù)訓(xùn)練模型;對(duì)于自監(jiān)督學(xué)習(xí),研究者使用了 DINO 和 SwAV 兩種不同方法得到的預(yù)訓(xùn)練模型。下表展示了使用四種不同方法得到預(yù)訓(xùn)練模型的平均準(zhǔn)確率,可以看出來(lái)無(wú)論哪種方法最后的結(jié)果都很高(注:Pascal-VOC 是一個(gè)比較小的數(shù)據(jù)集,所以準(zhǔn)確率相對(duì)低一點(diǎn)),并且對(duì)不同的預(yù)訓(xùn)練 backbone 具有穩(wěn)健性。