CV開啟大模型時代!谷歌發布史上最大ViT:220億參數,視覺感知力直逼人類
Transformer無疑是促進自然語言處理領域繁榮的最大功臣,也是GPT-4等大規模語言模型的基礎架構。
不過相比語言模型動輒成千上萬億的參數量,計算機視覺領域吃到Transformer的紅利就沒那么多了,目前最大的視覺Transformer模型ViT-e的參數量還只有40億參數。
最近谷歌發布了一篇論文,研究人員提出了一種能夠高效且穩定訓練大規模Vision Transformers(ViT)模型的方法,成功將ViT的參數量提升到220億。
論文鏈接:https://arxiv.org/abs/2302.05442
為了實現模型的擴展,ViT-22B結合了其他語言模型(如PaLM模型)的思路,使用 QK 歸一化改進了訓練穩定性,提出了一種異步并行線性操作(asynchronous parallel linear operations)的新方法提升訓練效率,并且能夠在硬件效率更高的Cloud TPU上進行訓練。
在對ViT-22B模型進行實驗以評估下游任務性能時,ViT-22B也表現出類似大規模語言模型的能力,即隨著模型規模的擴大,性能也在不斷提升。
ViT-22B 還可以應用于PaLM-e中,與語言模型結合后的大模型可以顯著提升機器人任務的技術水平。
研究人員還進一步觀察到規模帶來的其他優勢,包括更好地平衡公平性和性能,在形狀/紋理偏見方面與人類視覺感知的一致性,以及更好的穩健性。
模型架構
ViT-22B 是一個基于Transformer架構的模型,和原版ViT架構相比,研究人員主要做了三處修改以提升訓練效率和訓練穩定性。
并行層(parallel layers)
ViT-22B并行執行注意力塊和MLP塊,而在原版Transformer中為順序執行。
PaLM模型的訓練也采用了這種方法,可以將大模型的訓練速度提高15%,并且性能沒有下降。
query/key (QK) normalization
在擴展ViT的過程中,研究人員在80億參數量的模型中觀察到,在訓練幾千步之后訓練損失開始發散(divergence),主要是由于注意力logits的數值過大引起的不穩定性,導致零熵的注意力權重(幾乎one-hot)。
為了解決這個問題,研究人員在點乘注意力計算之前對Query和Key使用LayerNorm
在80億參數模型上的實驗結果如下圖所示,歸一化可以緩解發散問題。
刪除QKV投影和LayerNorms上的偏置項
和PaLM模型一樣,ViT-22B從QKV投影中刪除了偏置項,并且在所有LayerNorms中都沒有偏置項(bias)和centering,使得硬件利用率提高了3%,并且質量沒有下降。
不過與PaLM不同的是,ViT-22B對(內部和外部)MLP稠密連接層使用了偏置項,可以觀察到質量得到了改善,并且速度也沒有下降。
ViT-22B的編碼器模塊中,嵌入層,包括抽取patches、線性投影和額外的位置嵌入都與原始ViT中使用的相同,并且使用多頭注意力pooling來聚合每個頭中的per-token表征。
ViT-22B的patch尺寸為14×14,圖像的分辨率為224×224(通過inception crop和隨機水平翻轉進行預處理)。
異步并聯線性運算(asynchronous parallel linear operations)
大規模的模型還需要分片(sharding),即將模型參數分布在不同的計算設備中,除此之外,研究人員還把激活(acctivations,輸入的中間表征)也進行分片。
因為輸入和矩陣本身都是分布在各種設備上的,即使是像矩陣乘法這樣簡單的操作也需要特別小心。
研究人員開發了一種稱為異步并行線性運算的方法,可以在矩陣乘法單元(在TPU 中占據絕大多數計算能力的單元)中計算時,同時對設備之間的激活和權值進行通信。
異步方法最小化了等待傳入通信的時間,從而提高了設備效率。
異步并行線性運算的目標是計算矩陣乘法 y = Ax,但矩陣 A 和激活 x 都分布在不同的設備上,需要通過跨設備的重疊通信和計算來實現這一點。矩陣 A 在設備之間進行列分片(column-shard),每個矩陣包含一個連續的切片,每個塊表示為 Aij,更多細節請看原始論文。
實驗結果
為了說明ViT-22B學習到的表征非常豐富,研究人員使用LiT-tuning訓練一個文本模型來生成一些表征用來對齊文本和圖像。
下面是用Parti 和 Imagen 生成的分布外(out-of-distribution)圖像得到的實驗結果,可以看到ViT-22B的zero-shot圖像分類泛化能力非常強,僅從web上爬取的自然圖像就能識別出沒見過的物體和場景。
論文中還討論了ViT-22B在視頻分類、深度估計和語義分割任務上的效果。
與人類目標識別對齊
為了驗證 ViT-22B 分類決策與人類分類決策的一致性,研究人員對 ViT-22B 進行了微調,對分布外(OOD)數據集的不同分辨率進行了微調,其中人類比較數據可通過model-vs-human toolbox獲得。
該工具箱主要衡量三個關鍵指標: 模型如何處理失真(準確性) ?人和模型的精度(精度差)有什么不同?人和模型的錯誤模式(錯誤一致性)有多相似?
形狀偏差評估(值越大代表更多的形狀偏差)。許多視覺模型具有低形狀/高紋理偏差,而在 ImageNet 上進行微調的 ViT-22B具有迄今為止在 ML 模型中記錄的最高形狀偏差,更接近于人類形狀偏見
實驗結果顯示,雖然并非所有的微調解決方案都表現得很好,但 ViT-22B 變體在所有三個指標上都達到了新高。
此外,ViT-22B 模型在視覺模型中也有最高的形狀偏差記錄。這意味著他們主要使用目標的形狀,而不是目標的紋理來進行分類決策,策略結果類似于人類的感知(其形狀偏差為96%)。
標準模型(例如,ResNet-50有20-30% 的形狀偏差)通常根據紋理來分類,而高形狀偏差的模型則傾向于關注形狀(下圖識別為貓),盡管人類和模型的感知之間仍然存在許多差異,但是 ViT-22B 顯示出與人類視覺對象識別更多的相似性。
貓還是大象?車還是鐘?鳥還是自行車?具有某個物體的形狀和另一個不同物體紋理的圖像,可用于測量形狀/紋理偏差
分布外(out-of-distribution)性能
測量 OOD 數據集的性能有助于評估模型泛化性。
在這個實驗中,研究人員構建了從 JFT 到 ImageNet 的標簽映射,以及從 ImageNet 到不同的分布外數據集(如 ObjectNet)的標簽映射。
對這些數據進行預訓練后的結果如下圖所示,然后在 ImageNet 上對模型進行完全微調。
可以觀察到縮放 Vision Transformers 可以提高 OOD 性能: 即使 ImageNet 的精度達到飽和,也可以看到 ObjectNet 上從 ViT-e 換成 ViT-22B 模型可以顯著提升性能。
線性探測Linear Probe
線性探測是一種將單個線性層置于凍結模型之上的技術,與完全微調相比,這種方法的訓練成本更低,設置起來也更容易。
在 ImageNet 上訓練的線性探測結果,在 ImageNet-Real,ImageNet-v2,ObjectNet,ImageNet-R 和 ImageNet-A 數據集上評估,提供高分辨率微調 ViT-e/14作為參考
從結果中可以觀察到,ViT-22B 的線性探測性能接近于使用高分辨率圖像對較小模型進行全面微調的最先進水平,其中具有較高分辨率的訓練通常要昂貴得多,但可以在許多任務上取得更好的結果。
蒸餾
利用蒸餾法,可以將較大模型的知識轉化為較小模型的知識,可以提升成本更高、運行速度更慢的大模型的運行效率。
從實驗結果中可以發現,ViT-22B 的知識可以遷移到更小的模型,如 ViT-B/16和 ViT-L/16,并在同等模型尺寸下在ImageNet上刷新了性能記錄。
公平性與偏見
機器學習模型容易受到意想不到的不公平偏見的影響,例如找到錯誤的相關性或者在各個子群體之間存在性能差距,研究人員發現,擴大模型規模有助于緩解這些問題。
首先,規模是一個有前景的權衡方式,即使模型經過訓練后再進行后處理,將其人口平等(demographic parity)水平控制在規定的、可容忍的水平之下,性能也會隨著規模的增加而提高。
上圖: 去偏前 CelebA 中每個子組的精度。下圖: y 軸顯示了在這個例子中突出顯示的兩個特定亞組(女性和男性)的表現的絕對差異。與較小的 ViT 模型相比,ViT-22B 在性能的差距很小。
更重要的是,這不僅適用于以準確性衡量性能的情況,而且適用于其他度量,例如校準,即對模型估計概率的真實性的統計測量,所有子群的分類隨著規模的增大而趨于改善,并且ViT-22B 降低了各子群之間的性能差距。
結論
研究人員提出了一個目前最大的視覺Transformer模型 ViT-22B,包含220億參數。
通過對原始模型架構進行微小但關鍵的修改后,實現了更高的硬件利用率和訓練穩定性,從而得到了一個在幾個基準測試上提高了模型的上限性能。
使用凍結模型生成嵌入,只需要在頂部訓練幾層,即可獲得很好的性能,并且評估結果進一步表明,與現有模型相比,ViT-22B 在形狀和紋理偏差方面顯示出與人類視知覺更多的相似性,并且在公平性和穩健性方面提供了優勢。