一階優(yōu)化算法啟發(fā),北大林宙辰團(tuán)隊(duì)提出具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu)的設(shè)計(jì)方法
以神經(jīng)網(wǎng)絡(luò)為基礎(chǔ)的深度學(xué)習(xí)技術(shù)已經(jīng)在諸多應(yīng)用領(lǐng)域取得了有效成果。在實(shí)踐中,網(wǎng)絡(luò)架構(gòu)可以顯著影響學(xué)習(xí)效率,一個(gè)好的神經(jīng)網(wǎng)絡(luò)架構(gòu)能夠融入問題的先驗(yàn)知識(shí),穩(wěn)定網(wǎng)絡(luò)訓(xùn)練,提高計(jì)算效率。目前,經(jīng)典的網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)方法包括人工設(shè)計(jì)、神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索(NAS)[1]、以及基于優(yōu)化的網(wǎng)絡(luò)設(shè)計(jì)方法 [2]。人工設(shè)計(jì)的網(wǎng)絡(luò)架構(gòu)如 ResNet 等;神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索則通過搜索或強(qiáng)化學(xué)習(xí)的方式在搜索空間中尋找最佳網(wǎng)絡(luò)結(jié)構(gòu);基于優(yōu)化的設(shè)計(jì)方法中的一種主流范式是算法展開(algorithm unrolling),該方法通常在有顯式目標(biāo)函數(shù)的情況下,從優(yōu)化算法的角度設(shè)計(jì)網(wǎng)絡(luò)結(jié)構(gòu)。
然而,現(xiàn)有經(jīng)典神經(jīng)網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)大多忽略了網(wǎng)絡(luò)的萬有逼近性質(zhì) —— 這是神經(jīng)網(wǎng)絡(luò)具備強(qiáng)大性能的關(guān)鍵因素之一。因此,這些設(shè)計(jì)方法在一定程度上失去了網(wǎng)絡(luò)的先驗(yàn)性能保障。盡管兩層神經(jīng)網(wǎng)絡(luò)在寬度趨于無窮的時(shí)候就已具有萬有逼近性質(zhì) [3],在實(shí)際中,我們通常只能考慮有限寬的網(wǎng)絡(luò)結(jié)構(gòu),而這方面的表示分析的結(jié)果十分有限。實(shí)際上,無論是啟發(fā)性的人工設(shè)計(jì),還是黑箱性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索,都很難在網(wǎng)絡(luò)設(shè)計(jì)中考慮萬有逼近性質(zhì)?;趦?yōu)化的神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)雖然相對(duì)更具解釋性,但其通常需要一個(gè)顯式的目標(biāo)函數(shù),這使得設(shè)計(jì)的網(wǎng)絡(luò)結(jié)構(gòu)種類有限,限制了其應(yīng)用范圍。如何系統(tǒng)性地設(shè)計(jì)具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu),仍是一個(gè)重要的問題。
為了解決這個(gè)問題,北京大學(xué)林宙辰教授團(tuán)隊(duì)提出了一種易于操作的基于優(yōu)化算法設(shè)計(jì)具有萬有逼近性質(zhì)保障的神經(jīng)網(wǎng)絡(luò)架構(gòu)的方法,其通過將基于梯度的一階優(yōu)化算法的梯度項(xiàng)映射為具有一定性質(zhì)的神經(jīng)網(wǎng)絡(luò)模塊,再根據(jù)實(shí)際應(yīng)用問題對(duì)模塊結(jié)構(gòu)進(jìn)行調(diào)整,就可以系統(tǒng)性地設(shè)計(jì)具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu),并且可以與現(xiàn)有大多數(shù)基于模塊的網(wǎng)絡(luò)設(shè)計(jì)的方法無縫結(jié)合。論文還通過分析神經(jīng)網(wǎng)絡(luò)微分方程(NODE)的逼近性質(zhì)首次證明了具有一般跨層連接的神經(jīng)網(wǎng)絡(luò)的萬有逼近性質(zhì),并利用提出的框架設(shè)計(jì)了 ConvNext、ViT 的變種網(wǎng)絡(luò),取得了超越 baseline 的結(jié)果。論文被人工智能頂刊 TPAMI 接收。
- 論文:Designing Universally-Approximating Deep Neural Networks: A First-Order Optimization Approach
- 論文地址:https://ieeexplore.ieee.org/document/10477580
方法簡介
傳統(tǒng)的基于優(yōu)化的神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)方法通常從一個(gè)具有顯式表示的目標(biāo)函數(shù)出發(fā),采用特定的優(yōu)化算法進(jìn)行求解,再將優(yōu)化迭代格式映射為神經(jīng)網(wǎng)絡(luò)架構(gòu),例如著名的 LISTA-NN 就是利用 LISTA 算法求解 LASSO 問題所得 [4],這種方法受限于目標(biāo)函數(shù)的顯式表達(dá)式,可設(shè)計(jì)得到的網(wǎng)絡(luò)結(jié)構(gòu)有限。一些研究者嘗試通過自定義目標(biāo)函數(shù),再利用算法展開等方法設(shè)計(jì)網(wǎng)絡(luò)結(jié)構(gòu),但他們也需要如權(quán)重綁定等與實(shí)際情況可能不符的假設(shè)。
論文提出的易于操作的網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)方法從一階優(yōu)化算法的更新格式出發(fā),將梯度或鄰近點(diǎn)算法寫成如下的更新格式:
其中、
表示第 k 步更新時(shí)的(步長)系數(shù),再將梯度項(xiàng)替換為神經(jīng)網(wǎng)絡(luò)中的可學(xué)習(xí)模塊 T,即可得到 L 層神經(jīng)網(wǎng)絡(luò)的骨架:
整體方法框架見圖 1。
圖 1 網(wǎng)絡(luò)設(shè)計(jì)圖示
論文提出的方法可以啟發(fā)設(shè)計(jì) ResNet、DenseNet 等經(jīng)典網(wǎng)絡(luò),并且解決了傳統(tǒng)基于優(yōu)化設(shè)計(jì)網(wǎng)絡(luò)架構(gòu)的方法局限于特定目標(biāo)函數(shù)的問題。
模塊選取與架構(gòu)細(xì)節(jié)
該方法所設(shè)計(jì)的網(wǎng)絡(luò)模塊 T 只要求有包含兩層網(wǎng)絡(luò)結(jié)構(gòu),即,作為其子結(jié)構(gòu),即可保證所設(shè)計(jì)的網(wǎng)絡(luò)具有萬有逼近性質(zhì),其中所表達(dá)的層的寬度是有限的(即不隨逼近精度的提高而增長),整個(gè)網(wǎng)絡(luò)的萬有逼近性質(zhì)不是靠加寬
的層來獲得的。模塊 T 可以是 ResNet 中廣泛運(yùn)用的 pre-activation 塊,也可以是 Transformer 中的注意力 + 前饋層的結(jié)構(gòu)。T 中的激活函數(shù)可以是 ReLU、GeLU、Sigmoid 等常用激活函數(shù)。還可以根據(jù)具體任務(wù)在中添加對(duì)應(yīng)的歸一化層。另外,
時(shí),設(shè)計(jì)的網(wǎng)絡(luò)是隱式網(wǎng)絡(luò) [5],可以用不動(dòng)點(diǎn)迭代的方法逼近隱格式,或采用隱式微分(implicit differentiation)的方法求解梯度進(jìn)行更新。
通過等價(jià)表示設(shè)計(jì)更多網(wǎng)絡(luò)
該方法不要求同一種算法只能對(duì)應(yīng)一種結(jié)構(gòu),相反,該方法可以利用優(yōu)化問題的等價(jià)表示設(shè)計(jì)更多的網(wǎng)絡(luò)架構(gòu),體現(xiàn)其靈活性。例如,線性化交替方向乘子法通常用于求解約束優(yōu)化問題:通過令
即可得到一種可啟發(fā)網(wǎng)絡(luò)的更新迭代格式:
其啟發(fā)的網(wǎng)絡(luò)結(jié)構(gòu)可見圖 2。
圖 2 線性化交替方向乘子法啟發(fā)的網(wǎng)絡(luò)結(jié)構(gòu)
啟發(fā)的網(wǎng)絡(luò)具有萬有逼近性質(zhì)
對(duì)該方法設(shè)計(jì)的網(wǎng)絡(luò)架構(gòu),可以證明,在模塊滿足此前條件以及優(yōu)化算法(在一般情況下)穩(wěn)定、收斂的條件下,任意一階優(yōu)化算法啟發(fā)的神經(jīng)網(wǎng)絡(luò)在高維連續(xù)函數(shù)空間具有萬有逼近性質(zhì),并給出了逼近速度。論文首次在有限寬度設(shè)定下證明了具有一般跨層連接的神經(jīng)網(wǎng)絡(luò)的萬有逼近性質(zhì)(此前研究基本集中在 FCNN 和 ResNet,見表 1),論文主定理可簡略敘述如下:
主定理(簡略版):設(shè) A 是一個(gè)梯度型一階優(yōu)化算法。若算法 A 具有公式 (1) 中的更新格式,且滿足收斂性條件(優(yōu)化算法的常用步長選取均滿足收斂性條件。若在啟發(fā)網(wǎng)絡(luò)中均為可學(xué)習(xí)的,則可以不需要該條件),則由算法啟發(fā)的神經(jīng)網(wǎng)絡(luò):
在連續(xù)(向量值)函數(shù)空間以及范數(shù)
下具有萬有逼近性質(zhì),其中可學(xué)習(xí)模塊 T 只要有包含兩層形如
的結(jié)構(gòu)(σ 可以是常用的激活函數(shù))作為其子結(jié)構(gòu)都可以。
常用的 T 的結(jié)構(gòu)如:
1)卷積網(wǎng)絡(luò)中,pre-activation 塊:BN-ReLU-Conv-BN-ReLU-Conv (z),
2)Transformer 中:Attn (z) + MLP (z+Attn (z)).
主定理的證明利用了 NODE 的萬有逼近性質(zhì)以及線性多步方法的收斂性質(zhì),核心是證明優(yōu)化算法啟發(fā)設(shè)計(jì)的網(wǎng)絡(luò)結(jié)構(gòu)恰對(duì)應(yīng)一種收斂的線性多步方法對(duì)連續(xù)的 NODE 的離散化,從而啟發(fā)的網(wǎng)絡(luò) “繼承” 了 NODE 的逼近能力。在證明中,論文還給出了 NODE 逼近 d 維空間連續(xù)函數(shù)的逼近速度,解決了此前論文 [6] 的一個(gè)遺留問題。
表 1 此前萬有逼近性質(zhì)的研究基本集中在 FCNN 和 ResNet
實(shí)驗(yàn)結(jié)果
論文利用所提出的網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)框架設(shè)計(jì)了 8 種顯式網(wǎng)絡(luò)和 3 種隱式網(wǎng)絡(luò)(稱為 OptDNN),網(wǎng)絡(luò)信息見表 2,并在嵌套環(huán)分離、函數(shù)逼近和圖像分類等問題上進(jìn)行了實(shí)驗(yàn)。論文還以 ResNet, DenseNet, ConvNext 以及 ViT 為 baseline,利用所提出的方法設(shè)計(jì)了改進(jìn)的 OptDNN,并在圖像分類的問題上進(jìn)行實(shí)驗(yàn),考慮準(zhǔn)確率和 FLOPs 兩個(gè)指標(biāo)。
表 2 所設(shè)計(jì)網(wǎng)絡(luò)的有關(guān)信息
首先,OptDNN 在嵌套環(huán)分離和函數(shù)逼近兩個(gè)問題上進(jìn)行實(shí)驗(yàn),以驗(yàn)證其萬有逼近性質(zhì)。在函數(shù)逼近問題中,分別考慮了逼近 parity function 和 Talgarsky function,前者可表示為二分類問題,后者則是回歸問題,這兩個(gè)問題都是淺層網(wǎng)絡(luò)難以逼近的問題。OptDNN 在嵌套環(huán)分離的實(shí)驗(yàn)結(jié)果如圖 3 所示,在函數(shù)逼近的實(shí)驗(yàn)結(jié)果如圖 3 所示,OptDNN 不僅取得了很好的分離 / 逼近結(jié)果,而且比作為 baseline 的 ResNet 取得了更大的分類間隔和更小的回歸誤差,足以驗(yàn)證 OptDNN 的萬有逼近性質(zhì)。
圖 3 OptNN 逼近 parity function
圖 4 OptNN 逼近 Talgarsky function
然后,OptDNN 分別在寬 - 淺和窄 - 深兩種設(shè)定下在 CIFAR 數(shù)據(jù)集上進(jìn)行了圖像分類任務(wù)的實(shí)驗(yàn),結(jié)果見表 3 與 4。實(shí)驗(yàn)均在較強(qiáng)的數(shù)據(jù)增強(qiáng)設(shè)定下進(jìn)行,可以看出,一些 OptDNN 在相同甚至更小的 FLOPs 開銷下取得了比 ResNet 更小的錯(cuò)誤率。論文還在 ResNet 和 DenseNet 設(shè)定下進(jìn)行了實(shí)驗(yàn),也取得了類似的實(shí)驗(yàn)結(jié)果。
表 3 OptDNN 在寬 - 淺設(shè)定下的實(shí)驗(yàn)結(jié)果
表 4 OptDNN 在窄 - 深設(shè)定下的實(shí)驗(yàn)結(jié)果
論文進(jìn)一步選取了此前表現(xiàn)較好的 OptDNN-APG2 網(wǎng)絡(luò),進(jìn)一步在 ConvNext 和 ViT 的設(shè)定下在 ImageNet 數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn),OptDNN-APG2 的網(wǎng)絡(luò)結(jié)構(gòu)見圖 5,實(shí)驗(yàn)結(jié)果表 5、6。OptDNN-APG2 取得了超過等寬 ConvNext、ViT 的準(zhǔn)確率,進(jìn)一步驗(yàn)證了該架構(gòu)設(shè)計(jì)方法的可靠性。
圖 5 OptDNN-APG2 的網(wǎng)絡(luò)結(jié)構(gòu)
表 5 OptDNN-APG2 在 ImageNet 上的性能比較
表 6 OptDNN-APG2 與等寬(isotropic)的 ConvNeXt 和 ViT 的性能比較
最后,論文依照 Proximal Gradient Descent 和 FISTA 等算法設(shè)計(jì)了 3 個(gè)隱式網(wǎng)絡(luò),并在 CIFAR 數(shù)據(jù)集上和顯式的 ResNet 以及一些常用的隱式網(wǎng)絡(luò)進(jìn)行了比較,實(shí)驗(yàn)結(jié)果見表 7。三個(gè)隱式網(wǎng)絡(luò)均取得了與先進(jìn)隱式網(wǎng)絡(luò)相當(dāng)?shù)膶?shí)驗(yàn)結(jié)果,也說明了方法的靈活性。
表 7 隱式網(wǎng)絡(luò)的性能比較
總結(jié)
神經(jīng)網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)是深度學(xué)習(xí)中的核心問題之一。論文提出了一個(gè)利用一階優(yōu)化算法設(shè)計(jì)具有萬有逼近性質(zhì)保障的神經(jīng)網(wǎng)絡(luò)架構(gòu)的統(tǒng)一框架,拓展了基于優(yōu)化設(shè)計(jì)網(wǎng)絡(luò)架構(gòu)范式的方法。該方法可以與現(xiàn)有大部分聚焦網(wǎng)絡(luò)模塊的架構(gòu)設(shè)計(jì)方法相結(jié)合,可以在幾乎不增加計(jì)算量的情況下設(shè)計(jì)出高效的模型。在理論方面,論文證明了收斂的優(yōu)化算法誘導(dǎo)的網(wǎng)路架構(gòu)在溫和條件下即具有萬有逼近性質(zhì),并彌合了 NODE 和具有一般跨層連接網(wǎng)絡(luò)的表示能力。該方法還有望與 NAS、 SNN 架構(gòu)設(shè)計(jì)等領(lǐng)域結(jié)合,以設(shè)計(jì)更高效的網(wǎng)絡(luò)架構(gòu)。