ICLR 2024 | 雞生蛋蛋生雞?再論生成數據能否幫助模型訓練
隨著生成模型(如 ChatGPT、擴散模型)飛速發展,一方面,生成數據質量越來越高,到了以假亂真的程度;另一方面,隨著模型越來越大,也使得人類世界的真實數據即將枯竭。
面對這一處境,一個近期的研究熱度是,能否利用生成模型生成的假數據來輔助學習?學界對此也產生了許多爭論:到底是可以左腳踩右腳(bootsrap)地實現 weak-to-strong 的不斷提升,還是像雞生蛋、蛋生雞一樣,只不過是徒勞無功?
在近期 ICLR 2024 工作中,北大王奕森團隊針對這一「數據擴充」(Data Inflation)問題展開了深入研究。
他們針對對比學習(如 SimCLR、DINO、CLIP)這一常見的自監督學習場景,從理論和實驗兩方面分析了生成數據對于表示學習能力的影響。為了控制變量,他們保證生成模型和表示學習都只能使用同一個(無監督)真實數據集進行訓練,避免了擴充數據本身帶來的收益。
論文題目:
Do Generated Data Always Help Contrastive Learning?
論文鏈接:
??https://arxiv.org/abs/2403.12448??
代碼鏈接:
??https://github.com/PKU-ML/adainf??
他們發現,在這種情況下,生成數據并不總是對表示學習有幫助,在很多情況下甚至有害。比如,將 DDPM 的數據直接加入 CIFAR-10 訓練,反而導致分類準確率下降超過 1%(前人工作 [1] 也有類似發現:用生成數據擴充 ImageNet 后 ResNet-50 的分類準確率下降了 2.69%)。進一步分析表明,有兩個關鍵因素影響了生成數據的收益:
1. 真實數據和生成數據的比例。從人的角度來看,生成數據似乎以假亂真,但對于模型訓練而言并非如此。他們發現,真實數據與生成數據的混合比例在 10:1 附近時達到最優,也就是說,1 個真實數據的「訓練價值」約等于 10 個生成數據。這側面說明了二者的差異。
2. 訓練策略的設計。他們發現,在使用生成數據進行訓練時,如果維持原有的訓練參數,則模型幾乎沒有提升。相反,如果隨著數據集的擴充,而相應降低模型訓練所使用的數據增廣的強度,則可以獲得顯著提升。
針對這兩個核心觀察,本文還從自監督理論出發,解釋了他們內在的產生原因,并進而分析了數據量、數據質量與數據增廣強度之間的權衡取舍。
▲ 圖1 (a): 數據擴充流程 ;(b): 不同擴充策略下的對比學習性能
真實數據比生成數據的「訓練價值」
數據擴充最直觀的一個影響因素是生成數據的質量問題。下圖 2(a)表明,生成數據質量越高,對比學習的下游泛化能力越好,但遺憾的是即使是目前的 SOTA 生成模型 STF,也只讓模型的 Linear Accuracy(在特征上應用線性分類器的分類準確率)比此前僅上升 0.02%。
由于真實圖片包含更豐富、準確的信息,因此擴充后的數據集中真實數據和生成數據的地位不應該相同。本文研究通過在混合時對真實數據復制 N 倍的方式,對真實數據和生成數據進行重加權(Reweighting)。
圖 2(b)表明,混合比例在 10:1 時達到最優(weak augmentation)。本文進一步從理論上分析了重加權的作用,在此不做展開。
▲ 圖2 (a) 生成數據質量對對比學習的影響; (b) 數據重賦權對對比學習的影響
數據增廣與數據擴充,如何權衡?
在對比學習中,數據增強(Data Augmentation)的選取至關重要。通常來說,自監督學習需要使用較強的數據增強(如裁切、掩碼等)來學習的數據表示。為了區分,本文將生成數據視為數據擴充(Data Inflation),二者的區別是,數據擴充是擴大原始數據集的大小,而數據增廣是對每個原始樣本,在訓練過程中進行隨機增強。
直觀上看,數據擴充和數據增廣都會提升數據多樣性但數據增廣可能會改變圖像的語義信息(下圖 3),因此當數據擴充提供了足夠的數據時,便可以減弱數據增廣從而減小因圖像語義信息的改變帶來的誤差。
▲ 圖3. 數據增強可能改變圖片的語義信息
文中構造了四個不同規模的數據集:CIFAR-10、Half CIFAR-10(CIFAR-10 的一半)、CIFAR-10+10 萬張生成圖片、CIFAR-10+100 萬張生成圖片,通過改變 random resized crop(RRC)來反應不同的數據增廣強度。
下圖 4 中表明最優數據增廣強度隨著數據規模的增大而減小(Half CIFAR-10:0.02,CIFAR-10:0.08,CIFAR-10+0.1M:0.20,CIFAR-10+1M:0.30)。因此當進行數據擴充時,數據增廣強度需要減弱。也就是說,只有當二者搭配得當,才能充分發揮生成數據的作用。
▲ 圖4. 數據量和數據增廣強度的關系
基于增廣圖的理論理解
▲ 數據擴充后的下游泛化誤差上界
為了進一步刻畫數據擴充和數據增廣之間的關系,本文從圖的角度來建模對比學習:將數據增強產生的每個樣本視為圖 上的節點,并定義同一樣本產生的數據增廣樣本之間存在一條邊,這樣便在樣本空間構建了一個圖,稱為增廣圖(Augmentation Graph)[2,3]。
這是理解自監督學習的經典理論之一,根據這一建模,對比學習的下游泛化誤差上界可表示為
,其中
表示由于數據增強造成的標簽錯誤(labeling error),
表示增廣圖拉普拉斯矩陣的第
小的特征值,用于反應圖的連通性。
數據擴充和數據增廣對和
- 數據擴充:不會改變標簽錯誤
,但可以提升圖的連通性(
增大)(下圖 5 (a))。
- 數據增廣:數據增廣強度增加,會使得 labeling error
增大(圖 5 (b)),但同時使不同樣本之間的交疊部分增加,即增廣圖的連通性增強(
增大)(圖 5 (c))。
因此當數據擴充提升數據規模從而提供了足夠的圖的連通性時,為了進一步減小下游泛化誤差,可以減弱數據增廣強度從而使得 減小。反之數據規模比較小時,則需要更強的數據增強去獲得更好的圖的連通性。也就是說,數據擴充和數據增強在對比學習中存在互補作用,當數據擴充后,對應的最優數據增廣強度減小(圖 5(d))。
▲ 圖 5 數據擴充和數據增廣對 labeling error 和圖
的連通性的影響
基于以上的理解,論文提出自適應的數據擴充 Adaptive Inflation(AdaInf),根據生成數據的質量、大小,動態調整對比學習算法。其中,最重要的兩個指導原則是 1)真實數據和生成數據需賦予不同權重,生成數據質量越差權重應該越小;2)數據量增大后,應該減弱數據增廣強度,減少數據增強的負面作用。
實驗結果
本文主要考慮生成數據的規模遠大于真實數據的應用場景。為了在計算能力有限的情況下分析這一場景,作者主要考慮 CIFAR 數據集,因為可以在該數據集上采樣大量圖片。
以 CIFAR-10 為例,其中包含 5 萬真實訓練樣本,作者利用生成模型(GAN 或擴散模型)為它們添加 100 萬生成數據。以 10:1 的比例混合之后,作者將 CIFAR 數據集的總規模擴充到 150 萬。為了公平比較,本文保證全訓練過程中,生成模型也只能獲取 5 萬無監督數據。作者采用 SimCLR 作為默認方法并保持默認參數。
▲ 表1. 不同模型和不同數據集下的對比學習線性探測性能
本文在圖像識別任務上表 1 表明,AdaInf 在不同的對比學習模型和不同數據集上的性能顯著好于沒有數據擴充(No Inflation)或者直接進行數據擴充(Vanilla Inflation)。
僅使用基礎的 SimCLR 方法,AdaInf 就可以將 ResNet-18 上的自監督性能從 91.56 提升到 93.42,超越了大部分「魔改」的自監督學習方法,達到 Sota 水平。這進一步驗證了「數據為王」的規律,展示了 scaling 的潛力。
消融實驗:本文在下表 2 (a)中研究了 AdaInf 的組成部分:生成數據、數據重賦權、數據弱增廣。結果表明三者的重要性為數據弱增廣 > 數據重賦權 > 生成數據。這反映了數據擴充和數據增廣之間的相互作用對于對比學習的影響很大。
應用場景:作者進一步發現, AdaInf 可以很好地應用的數據缺乏的場景下。如表 2 (b)所示,當 CIFAR-10 每個類僅有 500 個樣本時,AdaInf 可以獲得更明顯的提升。
▲ 表2 (a) 消融實驗 (b) 數據匱乏場景下的應用
更多文章細節,請參考原文。
本文轉自 PaperWeekly ,作者:讓你更懂AI的
原文鏈接:??https://mp.weixin.qq.com/s/3iHewRj_IIgor_SIedbWjA?
