基于 boosting 原理訓練深層殘差神經網絡
1. 背景
1.1 Boosting
Boosting[1] 是一種訓練 Ensemble 模型的經典方法,其中一種具體實現 GBDT 更是廣泛應用在各類問題上。介紹boost的文章很多,這里不再贅述。簡單而言,boosting 方法是通過特定的準則,逐個訓練一系列弱分類,這些弱分類加權構成一個強分類器(圖1)。
1.2 殘差網絡
殘差網絡[2]目前是圖像分類等任務上***的模型,也被應用到語音識別等領域。其中核心是 skip connect 或者說 shortcut(圖2)。這種結構使梯度更易容向后傳導,因此,使訓練更深的網絡變得可行。
在之前的博文作為 Ensemble 模型的 Residual Network中,我們知道,一些學者將殘差網絡視一種特殊的 Ensemble 模型[3,4]。論文作者之一是Robert Schapire(剛注意到已經加入微軟研究院),AdaBoost的提出者(和 Yoav Freund一起)。Ensemble 的觀點基本算是主流觀點(之一)了。
2. 訓練方法
2.1 框架
- 殘差網絡
即這是一個線性分類器(Logistic Regression)。
- hypothesis module
其中 $C$ 為分類任務的類別數。
- weak module classifier
其中 $\alpha$ 為標量,也即 $h$ 是相鄰兩層 hypothesis 的線性組合。***層沒有更低層,因此,可以視為有一個虛擬的低層,$\alpha_0=0$ 并且 $、o_0(x)=0$。
- 將殘差網絡顯示表示為 ensemble
令殘差網絡的***輸出為 $F(x)$,并接合上述定義,顯然有:
這里用到了裂項求和的技巧(telescoping sum),因此作者稱提出的算法為 telescoping sum boosting.
我們只需要逐級(residual block)訓練殘差網絡,效果上便等同于訓練了一系列弱分類的 enemble。其中,除了訓練殘差網絡的權值外,還要訓練一些輔助的參數——各層的 $\alpha$ 及 $W$(訓練完成后即可丟棄)。
2.2 Telescoping Sum Boosting(裂項求和提升)
文章正文以二分類問題為例展開,我們更關心多分類問題,相關算法在附錄部分。文章給出的偽代碼說明相當清楚,直接復制如下:
其中,$\gamma_t$ 是一個標量;$C_t$ 是一個 m 乘 C (樣本數乘類別數)的矩陣,$C_t(i, j)$ 表示其中第 $i$ 行第 $j$ 列的元素。
需要特別說明的是,$st(x, l)$ 表示 $s_t(x)$的第 $l$ 個元素(此處符號用的略隨意:-);而 $st(x) = \sum{\tau=1}^t h\tau(x) = \alpha_t \cdot o_t(x) $。
與算法3中類似,$f(g(x_i), l)$ 表示 $f(g(x_i))$ 的第 $l$ 個元素,$g(x_i, y_i)$ 表示 $g(x_i)$ 的第 $i$ 個元素。
顯然 Algorithm 4 給的最小化問題可以用 SGD 優化,也可以數值的方法求解([1] 4.3 節)。
3. 理論
理論分部沒有詳細看。大體上,作者證明了 BoostResNet 保留為 boost 算法是優點:1)誤差隨網絡深度(即弱分類器數量)指數減?。?)抗過擬合性,模型復雜度承網絡深度線性增長。詳細可參見論文。
4. 討論
BoostResNet ***的特點是逐層訓練,這樣有一系列好處:
- 減少內存占用(Memory Efficient),使得訓練大型的深層網絡成為可能。(目前我們也只能在CIFAR上訓練千層的殘差網絡,過過干癮)
- 減少計算量(Computationally Efficient),每一級都只訓練一個淺層模型。
- 因為只需要訓練淺層模型,在優化方法上可以有更多的選擇(非SGD方法)。
- 另外,網絡層數可以依據訓練情況動態的確定。
4.2 一些疑問
文章應該和逐層訓練的殘差網絡(固定或不固定前面各層的權值)進行比較多,而不是僅僅比較所謂的 e2eResNet。
作者這 1.1 節***也提到,訓練框架不限于 ResNet,甚至不限于神經網絡。不知道用來訓練普通深度模型效果會怎樣,競爭 layer-wise pretraining 現在已經顯得有點過時了。
References
- Schapire & Freund. Boosting: Foundations and Algorithms. MIT.
- He et al. Deep Residual Learning for Image Recognition.
- Veit et al. Residual Networks Behave Like Ensembles of Relatively Shallow Networks.
- Xie et al. Aggregated Residual Transformations for Deep Neural Networks.