BERT在CNN上也能用?字節跳動研究成果中選ICLR 2023 Spotlight
?
如何在卷積神經網絡上運行 BERT?
你可以直接用 SparK —— 字節跳動技術團隊提出的稀疏層次化掩碼建模 (Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling),近期已被人工智能頂會收錄為 Spotlight 焦點論文:?
論文鏈接:?
??https://arxiv.org/pdf/2301.03580???
開源代碼:?
??https://github.com/keyu-tian/SparK???
這也是 BERT 在卷積神經網絡 (CNN) 上的首次成功。先來感受一下 SparK 在預訓練中的表現吧。?
輸入一張殘缺不全的圖片:?
還原出一只小狗:?
另一張殘缺圖片:?
原來是貝果三明治:?
其他場景也可實現圖片復原:?
BERT 和 Transformer 的天作之合?
“任何偉大的行動和思想,都有一個微不足道的開始。”?
在 BERT 預訓練算法的背后,是簡潔而深刻的設計。 BERT 使用“完形填空”:將一句話中的若干詞語進行隨機刪除,并讓模型學會恢復。?
BERT 非常依賴于 NLP 領域的核心模型 —— Transformer。?
Transformer 由于生來就適合處理可變長度的序列數據(例如一個英文句子),所以能輕松應付 BERT 完形填空的“隨機刪除”。?
視覺領域的 CNN 也想享受 BERT:兩個挑戰何在??
回顧計算機視覺發展史,卷積神經網絡模型凝練了平移等變性、多尺度結構等等眾多經典模型精華,可謂CV 界的中流砥柱。但與 Transformer 大相徑庭的是,CNN 天生無法適應經過完形填空“挖空”的、充滿“隨機孔洞”的數據,因此乍一看無法享受到 BERT 預訓練的紅利。?
上圖 a. 展示的是 MAE (Masked Autoencoders are Scalable Visual Learners) 這項工作,由于使用的是 Transformer 模型而非 CNN 模型,其可以靈活應對經過帶有空洞的輸入,乃與 BERT “天作之合”。?
而右圖 b. 則展示了一種粗暴融合 BERT 和 CNN 模型的方式——即把全部空洞區域“涂黑”,并將這張“黑馬賽克”圖輸入到 CNN 中,結果可想而知,會帶來嚴重的像素強度分布偏移問題,并導致很差的性能 (后文有驗證)。這就是阻礙 BERT 在 CNN 上成功應用的挑戰一。?
此外,作者團隊還指出,源自 NLP 領域的 BERT 算法,天然不具備“多尺度”的特點,而多尺度的金字塔結構在計算機視覺的悠久歷史中可謂“金標準”。單尺度的 BERT,和天然多尺度的 CNN 之間的沖突,則是挑戰二。?
解決方案 SparK:稀疏且層次化的掩碼建模?
作者團隊提出了 SparK (Sparse and hierarchical masKed modeling) 來解決前文兩個挑戰。?
其一,受三維點云數據處理的啟發,作者團隊提出將經過掩碼操作 (挖空操作) 后的零碎圖片視為稀疏點云,并使用子流形稀疏卷積 (Submanifold Sparse Convolution) 來進行編碼。這就讓卷積網絡能夠自如處理隨機刪除后的圖像。?
其二,受 UNet 優雅設計的啟發,作者團隊自然地設計了一種帶有橫向連接的編碼器-解碼器模型,讓多尺度特征在模型的多層次之間流動,讓 BERT 徹底擁抱計算機視覺的多尺度黃金標準。?
至此,一種為卷積網絡 (CNN) 量身定制的稀疏的、多尺度的掩碼建模算法 SparK 誕生了。?
SparK 是通用的:其可被直接運用在任何卷積網絡上,而無需對它們的結構進行任何修改,或引入任何額外的組件——不論是我們耳熟能詳的經典 ResNet,還是近期的先進模型 ConvNeXt,均可直接從 SparK 中受益。?
從 ResNet 到 ConvNeXt:三大視覺任務性能提升?
作者團隊選擇了具代表性的兩個卷積模型家族 ResNet 和 ConvNeXt,并在圖像分類,目標檢測、實例分割任務上進行了性能測試。?
在經典 ResNet-50 模型上,SparK 作為唯一的生成式預訓練,達到了 State-of-the-art 水準:?
在 ConvNeXt 模型上,SparK 依舊領先。在預訓練前,ConvNeXt 與 Swin-Transformer 平分秋色;而經預訓練后,ConvNeXt 在三個任務上均壓倒性超過了 Swin-Transformer:?
當從小到大,在完整的模型家族上驗證 SparK,便可觀察到:?
無論模型的大與小、新與舊,均可從 SparK 中受益,且隨著模型尺寸/訓練開銷的增長,漲幅甚至更高,體現出 SparK 算法的擴放 (scaling) 能力:?
最后,作者團隊還設計了一個驗證性的消融實驗,從中可見稀疏掩碼和層次化結構第3行和第4行) 均是非常關鍵的設計,一旦缺失就會造成嚴重的性能衰退:?