ICLR 2025 Spotlight | 慕尼黑工業大學&北京大學:邁向無沖突訓練的ConFIG方法
本文由慕尼黑工業大學與北京大學聯合團隊撰寫。第一作者劉強為慕尼黑工業大學博士生。第二作者楚夢渝為北京大學助理教授,專注于物理增強的深度學習算法,以提升數值模擬的靈活性及模型的準確性和泛化性。通訊作者 Nils Thuerey 教授(慕尼黑工業大學)長期研究深度學習與物理模擬,尤其是流體動力學模擬的結合,并曾因高效流動特效模擬技術獲奧斯卡技術獎。目前,其團隊重點關注可微物理模擬及物理應用中的先進生成式模型。
在深度學習的多個應用場景中,聯合優化多個損失項是一個普遍的問題。典型的例子包括物理信息神經網絡(Physics-Informed Neural Networks, PINNs)、多任務學習(Multi-Task Learning, MTL)和連續學習(Continual Learning, CL)。然而,不同損失項的梯度方向往往相互沖突,導致優化過程陷入局部最優甚至訓練失敗。
目前,主流的方法通常通過調整損失權重來緩解沖突。例如在物理信息神經網絡中,許多研究從數值剛度、損失的收斂速度差異和神經網絡的初始化角度提出了許多權重方法。然而,盡管這些方法聲稱具有更高的解的精度,但目前對于最優的加權策略尚無共識。
針對這一問題,來自慕尼黑工業大學和北京大學的聯合研究團隊提出了 ConFIG(Conflict-Free Inverse Gradients,無沖突逆梯度)方法,為多損失項優化提供了一種穩定、高效的優化策略。ConFIG 提供了一種優化梯度,能夠防止由于沖突導致優化陷入某個特定損失項的局部最小值。ConFIG 方法可以在數學上證明其收斂特性并具有以下特點:
- 最終更新梯度
與所有損失項的優化梯度均不沖突。
在每個特定損失梯度上的投影長度是均勻的,可以確保所有損失項以相同速率進行優化。
長度可以根據損失項之間的沖突程度自適應調整。
此外,ConFIG 方法還引入了一種基于動量的變種。通過計算并緩存每個損失項梯度的動量,可以避免在每次訓練迭代中計算所有損失項的梯度。結果表明,基于動量的 ConFIG 方法在顯著降低訓練成本的同時保證了優化的精度。
想深入了解 ConFIG 的技術細節?我們已經為你準備好了完整的論文、項目主頁和代碼倉庫!
- 論文地址:https://arxiv.org/abs/2408.11104
- 項目主頁:https://tum-pbs.github.io/ConFIG/
- GitHub: https://github.com/tum-pbs/ConFIG
ConFIG: 無沖突逆梯度方法
目標:給定個損失函數
,其對應梯度為
。我們希望找到一個優化方向
,使其滿足:
。即所有損失項在該方向上都能減少,從而避免梯度沖突。
無沖突優化區間
假設存在一個無沖突更新梯度,我們可以引入一個新的矢量。由于
是一個無沖突梯度,
應為一個正向分量矢量。同樣地,我們也可以預先定義一個正向分量矢量
,然后直接通過矩陣的逆運算求得無沖突更新梯度
,即
。通過給定不同的正向分量矢量
,我們得到由一系列不同
組成的無沖突優化區間。
確定唯一優化梯度
盡管通過簡單求逆可以獲得一個無沖突更新區間,我們需要進一步確定唯一的無沖突梯度用于優化。在 ConFIG 方法中,我們從方向和幅度兩個方面進一步限定了最終用于優化更新的梯度:
- 具體優化方向:相比于直接求解梯度矩陣的逆,ConFIG 方法求解了歸一化梯度矩陣的逆,即
,其中
表示第
個梯度向量的單位向量。可以證明,變換后
矢量的每個分量代表了每個梯度
與最終更新梯度
之間的余弦相似度。因此,通過設定
分量的不同值可以直接控制最終更新梯度對于每個損失梯度的優化速率。在 ConFIG 中,
被設定為單位矢量以確保每個損失具有相同的優化強度從而避免某些損失項的優化被忽略。
- 優化梯度大小:此外,ConFIG 方法還根據梯度沖突程度調整步長。當梯度方向較一致時,加快更新;當梯度沖突嚴重時,減小更新幅度:
, 其中
為每個梯度與最終更新方向之間的余弦相似度。
ConFIG 方法獲得最終無沖突優化方向的計算過程可以總結為:
原論文中給出了上述 ConFIG 更新收斂性的嚴格證明。同時,我們還可以證明只要參數空間的維度大于損失項的個數,ConFIG 運算中的逆運算總是可行的。
M-ConFIG: 結合動量加速訓練
ConFIG 方法引入了矩陣的逆運算,這將帶來額外的計算成本。然而與計算每個損失的梯度帶來的計算成本,其并不顯著。在包括 ConFIG 在內的基于梯度的方法中,總是需要額外的反向傳播步驟獲得每個梯度相對于訓練參數的梯度。這使得基于梯度的方法的計算成本顯著高于標準優化過程和基于權重的方法。為此,我們引入了 M-ConFIG 方法,使用動量加速優化:
- 使用梯度的動量(指數移動平均)代替梯度進行 ConFIG 運算。
- 在每次優化迭代中,僅對一個或部分損失進行反向傳播以更新動量。其它損失項的動量采用之前迭代步的歷史值。
在實際應用中,M-ConFIG 的計算成本往往低于標準更新過程或基于權重的方法。這是由于反向傳播一個子損失往往要比反向傳播總損失
更快。這在物理信息神經網絡中尤為明顯,因為邊界上的采樣點通常遠少于計算域內的采樣點。在我們的實際測試中,M-ConFIG 的平均計算成本為基于權重方法的 0.56 倍。
結果:更快的收斂,更優的預測
物理信息神經網絡
在物理信息神經網絡中,用神經網絡的自動微分來近似偏微分方程的時空間導數。偏微分方程的殘差項與邊界條件和初始條件被視作不同的損失項在訓練過程中進行聯合優化。我們在多個經典的物理神經信息網絡中測試了 ConFIG 方法的表現。
結果顯示,在相同訓練迭代次數下,ConFIG 方法是唯一一個相比于標準 Adam 方法始終獲得正向提升的方法。對每個損失項變化的單獨分析表明,ConFIG 方法在略微提高 PDE 訓練殘差的同時大幅降低了邊界和初始條件損失
,實現了 PDE 訓練精度的整體提升。
相同迭代步數下不同方法在 PINNs 測試中相比于 Adam 優化器的相對性能提升
不同損失項隨著訓練周期的變化情況
在實際應用中,相同訓練時間下的模型準確性可能更為重要。M-ConFIG 方法通過使用動量近似梯度帶來的運算速度提升可以使其充分發揮潛力。在相同訓練時間內,M-ConFIG 方法的測試結果優于其他所有方法,甚至高于常規的 ConFIG 方法。
此外,我們還在最具有挑戰性的三維 Beltrami 流動中進一步延長訓練時間來更加深入地了解 M-ConFIG 方法的性能。結果表明,M-ConFIG 方法并非僅在優化初始階段帶來顯著的性能改善,而是在整個優化過程中都持續改善優化的過程。
相同訓練時間下不同方法在 PINNs 測試中相比于 Adam 優化器的相對性能提升
三維 Beltrami 流動案例中預測誤差隨著訓練時間的變化
多任務學習
我們還測試了 ConFIG 方法在多任務學習(MTL)方面的表現。我們采用經典的 CelebA 數據集,其包含 20 萬張人臉圖像并標注了 40 種不同的面部二元屬性。對每張人像面部屬性的學習是一個非常有挑戰的 40 項損失的多任務學習。
實驗結果表明,ConFIG 方法或 M-ConFIG 方法在平均 F1 分數、平均排名
中均表現最佳。其中,對于 M-ConFIG 方法,我們在一次迭代中更新 30 個動量而不僅更新一個動量。這是因為當任務數量增加時,單個動量更新時間的間隔較長,歷史動量信息難以準確捕捉梯度的變化。動量信息的滯后會逐漸抵消 M-ConFIG 方法更高訓練效率帶來的性能提升。
在我們的測試中,當任務數量等于 10 時,M-ConFIG 方法在相同訓練時間下的性能就已經弱于 ConFIG 方法。增加單次迭代過程中的動量更新次數可以顯著緩解這種性能下降。在標準的 40 任務 CelebA 訓練中將動量更新次數提升到 20 時,M-ConFIG 方法的性能已經接近 ConFIG 方法,而訓練時間僅為 ConFIG 方法的 56%。當更新步數達到 30 時,其性能甚至可以優于 ConFIG 方法。
ConFIG 方法在 CelebA 人臉屬性數據集中的表現
結論
在本研究中,我們提出了 ConFIG 方法來解決不同損失項之間的訓練沖突。ConFIG 方法通過確保最終更新梯度與每個子梯度之間的正點積來確保無沖突學習。此外,我們還發展了一種基于動量的方法,用交替更新的動量代替梯度,顯著提升了訓練效率。ConFIG 方法有望為眾多包含多個損失項的深度學習任務帶來巨大的性能提升。