清華AIR ModelMerging:無需訓練數據!合并多個模型實現任意場景的感知(ECCV'24)
近日,來自清華大學智能產業研究院(AIR)助理教授趙昊老師的團隊,聯合戴姆勒公司,提出了一種無需訓練的多域感知模型融合新方法。研究重點關注場景理解模型的多目標域自適應,并提出了一個挑戰性的問題:如何在無需訓練數據的條件下,合并在不同域上獨立訓練的模型實現跨領域的感知能力?團隊給出了“Merging Parameters + Merging Buffers”的解決方案,這一方法簡單有效,在無須訪問訓練數據的條件下,能夠實現與多目標域數據混合訓練相當的結果。
論文題目:
Training-Free Model Merging for Multi-target Domain Adaptation
作者:Wenyi Li, Huan-ang Gao, Mingju Gao, Beiwen Tian, Rong Zhi, Hao Zhao
1 背景介紹
一個適用于世界各地自動駕駛場景的感知模型,需要能夠在各個領域(比如不同時間、天氣和城市)中都輸出可靠的結果。然而,典型的監督學習方法嚴重依賴于需要大量人力標注的像素級注釋,這嚴重阻礙了這些場景的可擴展性。因此,多目標域自適應(Multi-target Domain Adaptation, MTDA)的研究變得越來越重要。多目標域自適應通過設計某種策略,在訓練期間同時利用來自多個目標域的無標簽數據以及源域的有標簽合成數據,來增強這些模型在不同目標域上的魯棒性。
與傳統的單目標域自適應 (Single-target Domain Adaptation, STDA)相比,MTDA 面臨更大的挑戰——一個模型需要在多個目標域中都能很好工作。為了解決這個問題,以前的方法采用了各種專家模型之間的一致性學習和在線知識蒸餾來構建各目標域通用的學生模型。盡管如此,這些方法的一個重大限制是它們需要同時使用所有目標數據,如圖1(b) 所示。
但是,同時訪問到所有目標數據是不切實際的。一方面原因是數據傳輸成本限制,因為包含數千張圖像的數據集可能會達到數百 GB。另一方面,從數據隱私保護的角度出發,不同地域間自動駕駛街景數據的共享或傳輸可能會受到限制。面對這些挑戰,在本文中,我們聚焦于一個全新的問題,如圖1(c) 所示。我們的研究任務仍然是MTDA,但我們并沒有來自多個目標域的數據,而是只能獲得各自獨立訓練的模型。我們的目標是,通過某種融合方式,將這些模型集成為一個能夠適用于各個目標域的模型。
圖1:不同實驗設置的對比
2 方法
如何將多個模型合并為一個,同時保留它們在各自領域的能力?我們提出的解決方案主要包括兩部分:Merging Parameters(即可學習層的weight和bias)和 Merging Buffers(即normalization layers的參數)。在第一階段,我們從針對不同單目標域的無監督域自適應模型中,得到訓練后的感知模型。然后,在第二階段,利用我們提出的方法,在無須獲取任何訓練數據的條件下,只對模型做合并,得到一個在多目標域都能工作的感知模型。
圖2:整體實驗流程
下面,我們將詳細介紹這兩種合并的技術細節和研究動機。
2.1 Merging Parameters
2.1.1 Permutation-based的方法出現退化
事實上,如何將模型之間可學習層的 weight 和 bias 合并一直是一個前沿研究領域。在之前的工作中,有一種稱為基于置換 (Permutation-based) 的方法。這些方法基于這樣的假設:當考慮神經網絡隱藏層的所有潛在排列對稱性時,loss landscape 通常形成單個盆地(single basin)。因此,在合并模型參數 和時,這類方法的主要目標是找到一組置換變換 ,確保 在功能上等同于 ,同時也位于參考模型 附近的近似凸盆地(convex basin)內。之后,通過簡單的中點合并 以獲得一個合并后的模型 ,該模型能夠表現出比單個模型更好的泛化能力,
在我們的實驗中,模型 和 在第一階段都使用相同的網絡架構進行訓練,并且,源數據都使用相同的合成圖像和標簽。我們最初嘗試采用了一種 Permutation-based 的代表性方法——Git Re-Basin,該方法將尋找置換對稱變換的問題轉化為線性分配問題 (LAP),是目前最高效實用的算法。
圖3:Git Re-basin和mid-point的實驗結果對比
但是,如圖3所示,我們的實驗結果出乎意料地表明,不同網絡架構(ResNet50、ResNet101 和 MiT-B5)下 Git Re-Basin 的性能與簡單中點合并相同。進一步的研究表明,Git Re-Basin 發現的排列變換在解決 LAP 的迭代中保持相同的排列,這表明在我們的領域適應場景下,Git Re-Basin 退化為一種簡單的中點合并方法。
2.1.2 線性模式連通性的分析
我們從線性模式連通性(linear mode connectivity)的視角進一步研究上述退化問題。具體來說,我們使用連續曲線 在參數空間中連接模型 和模型 。在這種特定情況下,我們考慮如下線性路徑,
接下來,我們通過對 做插值遍歷評估模型的性能。為了衡量這些模型在兩個指定目標域(分別表示為 和 )上的有效性,我們使用調和平均值 (Harmonic Mean) 作為主要評估指標,
我們之所以選擇調和平均值作為指標,是因為它能夠賦予較小的值更大的權重,這能夠更好應對世界各地各個城市中最差的情況。它有效地懲罰了模型在一個目標域(例如,在發達的大城市)的表現異常高,而其他目標域(例如,在第三世界鄉村)表現低的情況。不同插值的實驗結果如圖4(a)所示。“CS”和“IDD”分別表示目標數據集 Cityscapes 和 Indian Driving Dataset。
圖4:線性模式連通性的分析實驗
2.1.3 理解線性模式連通性的原因
在上述實驗結果的基礎上,我們進一步探究:在先前域自適應方法中觀察到的線性模式連通性,背后的根本原因是什么?為此,我們進行了消融實驗,來研究第一階段訓練 和 期間的幾個影響因素。
- 合成數據。使用相同的合成數據可以作為兩個域之間的橋梁。為了評估這一點,我們將合成數據集 GTA 中的訓練數據劃分為兩個不同的非重疊子集,每個子子集包含原始訓練樣本的 30%。在劃分過程中,我們將合成數據集提供的具有相同場景標識的圖像分組到同一個子集中,而具有顯著差異的場景則放在單獨的子集中。我們使用這兩個不同子集分別作為源域,訓練兩個單目標域自適應模型(目標域為 CityScapes 數據集)。隨后,我們研究這兩個 STDA 模型的線性模式連通性。結果如圖 4(b) 所示,可以觀察到,在參數空間內連接兩個模型的線性曲線上,性能沒有明顯下降。這一觀察結果表明,使用相同的合成數據并不是影響線性模式連通性的主要因素。
- 自訓練架構。使用教師-學生模型可能會將最后的模型限制在 loss landscape 的同一 basin 中。為了評估這種可能性,我們禁用了教師模型的指數移動平均 (EMA) 更新。相應地,我們在每次迭代中將學生權重直接復制到教師模型中。隨后,我們繼續訓練兩個單目標域自適應模型,分別利用 GTA 作為源域,Cityscapes 和 IDD 作為目標域。然后,我們研究在參數空間內連接兩個模型的線性曲線,結果如圖 4(c) 所示。我們可以看到線性模式連接屬性保持不變。
- 初始化和預訓練。 使用相同的預訓練權重初始化 backbone 的做法,可能會使模型在訓練過程中難以擺脫的某一 basin。為了驗證這種潛在情況,我們初始化兩個具有不同權重的獨立 backbone,然后繼續針對 Cityscapes 和 IDD 進行域自適應。在評估兩個收斂模型之間的線性插值模型時,我們觀察到性能明顯下降,如圖 4(d) 所示。為了更深入地了解潛在因素,我們繼續探究,是相同的初始權重,還是預訓練過程導致了這種影響? 我們初始化兩個具有相同權重但沒有預訓練的主干,然后再次進行實驗。有趣的是,我們發現,在參數空間的線性連接曲線仍然遇到了巨大的性能障礙,如圖 4(e) 所示。這意味著預訓練過程在模型中的線性模式連接方面起著關鍵作用。
2.1.4 關于合并參數的小結
我們通過大量實驗證明,當領域自適應模型從相同的預訓練權重開始時,模型可以有效地過渡到不同的目標領域,同時仍然保持參數空間中的線性模式連通性。因此,這些訓練模型可以通過簡單的中點合并,得到在兩個領域都有效的合并模型。
2.2 Merging Buffers
Buffers,即批量歸一化 (BN) 層的均值和方差,與數據域密切相關。因為數據不同的方差和均值代表了域的某些特定特征。在合并模型時如何有效地合并 Buffers 的問題通常被忽視,因為現有方法主要探究如何合并在同一域內的不同子集上訓練的兩個模型。在這樣的前提下,之前的合并方法不考慮 Buffers 是合理的,因為來自任何給定模型的 Buffers 都可以被視為對整個總體的無偏估計,盡管它完全來自隨機數據子樣本。
但是,在我們的實驗環境中,我們正在研究如何合并在完全不同的目標域中訓練的兩個模型,這使得 Buffers 合并的問題不再簡單。由于我們假設在模型 A 和模型 B 的合并階段無法訪問任何形式的訓練數據,因此我們可用的信息僅限于 Buffers 集 。其中, 表示 BN 層的數量,而 、 和 分別表示第 層的平均值、標準差和 tracked 的批次數。生成 BN 層的統計數據如下:
以上方程背后的原理可以解釋如下:引入 BN 層是為了緩解內部協變量偏移(internal covariate shift)問題,其中輸入的均值和方差在通過內部可學習層時會發生變化。在這種情況下,我們的基本假設是,后續可學習層合并的 BN 層的輸出遵循正態分布。由于生成的 BN 層保持符合高斯先驗的輸入歸納偏差,我們根據從 和 得到的結果估計 和 。如圖5所示,我們獲得了從該高斯先驗中采樣的兩組數據點的均值和方差,以及這些集合的大小。我們利用這些值來估計該分布的參數。
圖5:合并BN層的示意圖
當將 Merging Buffers 方法擴展到 個高斯分布時,tracked 的批次數 、均值的加權平均值 和方差的加權平均值可以按如下方式計算。
3 實驗與結果
3.1 數據集
在多目標域適應實驗中,我們使用 GTA 和 SYNTHIA 作為合成數據集,并使用 Cityscapes 、Indian Driving Dataset 、ACDC 和 DarkZurich 的作為目標域真實數據集。在訓練單個領域自適應模型時,使用帶有標記的源域數據和無標記的目標域數據。接下來,我們采用所提出的模型融合技術,直接從訓練好的模型出發構建混合模型,這個過程中無需使用訓練數據。
3.2 與Baseline模型的比較
在實驗中,我們將我們的模型融合方法在 MTDA 任務上的結果與幾種 baseline 模型進行對比。baseline 模型包括數據組合(Data Comb.)方法,其中單個域自適應模型在來自兩個目標域的混合數據上進行訓練(這個baseline僅供參考,因為它們與我們關于數據傳輸帶寬和數據隱私問題的設定相矛盾)。baseline 模型還包括單目標域自適應(STDA),即為單一目標域訓練的自適應模型,評估其在兩個域上的泛化能力。
表1:與Baseline模型的比較
表 1 展示了基于 CNN 架構的 ResNet101和基于 Transformer 架構的 MiT-B5 的結果。與最好的單目標域自適應模型相比,當將我們的方法分別應用于 ResNet101 和 MiT-B5 兩種不同 Backbone 時,在兩個目標域上性能的調和平均值分別提高 +4.2% 和 +1.2%。值得注意的是,這種性能水平(ResNet101架構下的調和平均值為 56.3%)已經與數據組合(Data Comb.)方法(56.2%)相當,而且我們無需訪問任何訓練數據即可實現這一目標。
此外,我們探索了一種更為寬松的條件,其中僅合并 Encoder backbone,而 decoder head 則針對各個下游域進行分離。值得注意的是,這種條件下,分別使兩種 backbone 下的調和平均性能顯著提高 +5.6% 和 +2.5%。我們還發現,我們提出的方法在大多數類別中能夠始終實現最佳調和平均,這表明它能夠增強全局適應性,而不是偏向某些類別。
3.3 與SoTA模型的比較
我們首先將我們的方法與 GTACityscapes 任務上的單目標域自適應 (STDA) 進行比較,如表 2 所示。值得注意的是,我們的方法可以應用于任何這些方法,只要它們使用相同的預訓練權重適應不同的域。這使我們能夠使用單個模型推廣到所有目標域,同時保持 STDA 方法相對優越的性能。
表2:與SoTA模型的比較
我們還將我們的方法與表 2 中的域泛化(DG)方法進行了比較,域泛化旨在將在源域上訓練的模型推廣到多個看不見的目標域。我們的方法無需額外的技巧,只需利用參數空間的線性模式連接即可實現卓越的性能。在多目標域自適應領域,我們的方法也取得了領先。我們不需要對多個學生模型做顯式的域間一致性正則化或知識提煉,但能使 STDA 方法中的技術(如多分辨率訓練)能夠輕松轉移到 MTDA 任務。可以觀察到,我們對 MTDA 任務的最佳結果做出了的顯著改進,同時消除了對訓練數據的依賴。
3.4 多目標域拓展
我們還擴展了我們的模型融合技術,以涵蓋四個不同的目標領域:Cityscapes 、IDD 、ACDC 和 DarkZurich 。每個領域都面臨著獨特的挑戰和特點:Cityscapes 主要關注歐洲城市環境,IDD 主要體現印度道路場景,ACDC 主要針對霧、雨或雪等惡劣天氣條件,DarkZurich 則主要處理夜間道路場景。我們對針對每個領域單獨訓練后的模型,以及用我們的方法融合后的模型進行了全面評估。
表3:在4個目標域上的實驗結果
如表 3 所示,我們提出的模型融合技術表現出顯著的性能提升。雖然我們將來自單獨訓練模型的調和平均值最高的方法作為比較的基線,但所有基于模型融合的方法都優于它,性能增長高達 +5.8%。此外,盡管合并來自多個不同領域模型的復雜性不斷增加,但我們觀察到所有領域的整體性能并沒有明顯下降。通過進一步分析,我們發現我們的方法能夠簡化領域一致性的復雜性。現有的域間一致性正則化和在線知識提煉方法的復雜度為 ,而我們的方法可以將其減少到更高效的 ,其中 表示考慮的目標域數量。
3.5 消融實驗
我們使用 ResNet101 和 MiT-B5 作為分割網絡中的圖像編碼器,對我們提出的 Merging Parameters 和 Merging Buffers 方法進行了消融研究,結果如表 4 所示。我們觀察到單目標域自適應 (STDA) 模型在不同域中的泛化能力存在差異,這主要源于所用目標數據集的多樣性和質量差異。盡管如此,我們還是選擇 STDA 模型中的最高的調和平均值作為比較基線。
表4:消融實驗
表 4(a) 和 4(b) 中的數據顯示,采用簡單的中點合并方法對參數進行處理,可使模型的泛化能力提高 +2.7% 和 +0.6%。此外,當結合 Merging Buffers 時,這種性能的增強會進一步放大到 +4.2% 和+1.2%。我們還觀察到 MiT-B5 作為 backbone 時的一個有趣現象:在 IDD 域中進行評估時,融合模型的表現優于單目標自適應模型。這一發現意味著模型可以從其他域獲取域不變的知識。這些結果表明,我們提出的模型融合技術的每個部分都是有效的。
3.6 模型融合在分類任務上的應用
我們還通過實驗驗證了我們所提出的模型融合方法在圖像分類任務上的有效性。通過將 CIFAR-100 分類數據集劃分為兩個不同的、不重疊的子集,我們在這些子集上獨立訓練兩個 ResNet50 模型,標記為 A 和 B。這種訓練要么從一組共同的預訓練權重中進行,要么從兩組隨機初始化的權重中進行。模型 A 和 B 的性能結果如圖 6 所示。結果表明,從相同的預訓練權重進行融合的模型優于在任何單個子集上訓練的模型。相反,當從隨機初始化的權重開始時,單個模型表現出學習能力,而合并模型的性能類似于隨機猜測。
圖6:CIFAR-100 分類任務上的模型融合結果
隨機初始化會破壞模型線性平均性,而相同的預訓練主干會導致線性模式連接。我們在另一個預訓練權重上再次驗證了這個結論。圖 7 中的結果表明,DINO 預訓練和 ImageNet 預訓練在模型參數空間中具有不同的loss landscape,模型的融合必須在相同的loss landscape內進行。
圖7:ImageNet和DINO預訓練權重對線性模式連接的影響
4 結論
本文介紹了一種新穎的模型融合策略,旨在解決多目標域自適應 (MTDA)問題,同時無需依賴訓練數據。研究結果表明,在大量數據集上進行預訓練時,基于 CNN 的神經網絡和基于 Transformer 的視覺模型都可以將微調后模型限制在 loss landscape 的相同 basin 中。我們還強調了 Buffers 的合并在 MTDA 中的重要性,因為 Buffers 是捕獲各個域獨特特征的關鍵。我們所提出的模型融合方法簡單而高效,在 MTDA 基準上取得了最好的評測性能。我們期待本文所提出的模型融合方法能夠激發未來更多關于這個領域的探索。