破解ChatGPT驚人耗電!DeepMind新算法訓練提效13倍,能耗暴降10倍
ChatGPT早已成為世界耗能大戶:一天用掉超50萬度電,相當于1.7萬個美國家庭的用電量!
然而,大模型對能源的吞噬,遠不僅如此。
國際能源署(IEA)預測,從2022年到2026年,數(shù)據(jù)中心的用電量將翻一番。
隨著AI計算需求的膨脹,還需要用水來冷卻計算系統(tǒng)。研究稱,微軟用水量從2021年到22年飆升了34%,ChatGPT每處理5-50個提示就會消耗接近半升水。
針對這種現(xiàn)狀,我們有更好的解決策略嗎?
最近,谷歌DeepMind研究團隊提出了一種加快AI訓練的新方法——多模態(tài)對比學習與聯(lián)合示例選擇(JEST),大大減少了所需的計算資源和時間。
JEST以13倍更少的迭代次數(shù),以及10倍更少的計算量,超越了最先進的模型!
論文地址:https://arxiv.org/pdf/2406.17711
預訓練的參考模型,已經(jīng)學習了什么樣的數(shù)據(jù)是有「優(yōu)質(zhì)的」或「有用的」。然后通過模型,來引導數(shù)據(jù)選擇那些精心篩選過的小型數(shù)據(jù)集。
這一發(fā)現(xiàn)揭示了,數(shù)據(jù)篩選水平可以作為評判Scaling Law的一個新維度。
網(wǎng)友激動表示,「我沒想到這么快就會發(fā)生。模型能夠自主選擇訓練數(shù)據(jù)的能力是巨大的,因為它使訓練變得顯著更容易,你不再需要猜測什么是高質(zhì)量的訓練數(shù)據(jù),你有一個能夠『理解』什么樣的數(shù)據(jù)對自身學習最有價值的模型」。
前谷歌、蘋果軟件工程師稱贊道,這項研究非常令人印象深刻。
從「超級batch」中篩選數(shù)據(jù)
無論是語言、視覺還是多模態(tài)模型,數(shù)據(jù)質(zhì)量是預訓練性能的重要驅(qū)動因素。比如Phi-3、Gemma 2等模型的成功讓我們看到了,更少、更高質(zhì)量的數(shù)據(jù)有可能實現(xiàn)更強大的性能。
要篩選出高質(zhì)量的數(shù)據(jù),數(shù)據(jù)管道的建立就成為重要的工作。現(xiàn)有的方法大體可以分為兩種:1)手動管理 2)基于模型的數(shù)據(jù)管理,用正在訓練模型的特征選擇高質(zhì)量數(shù)據(jù)。
前者成本高昂且難以擴展,后者則有望為多模態(tài)LLM實現(xiàn)Scaling Law。
然而,現(xiàn)有方法忽略了一個事實。
如果僅在單個數(shù)據(jù)點的層面進行篩選,就沒有考慮到數(shù)據(jù)集以及batch的總體組成。畢竟,訓練數(shù)據(jù)是以batch為單位,數(shù)據(jù)點之間的依賴性不可忽視。
許多計算機視覺的研究都曾表明,hard negatives(表達空間中相近但標簽不同的樣本)相比可被平凡解的數(shù)據(jù)簇,能提供更有效的學習信號。
那么如何讓模型以batch為單位篩選數(shù)據(jù)呢?
論文提出的JEST算法正是要解決這個問題,原理很好理解:就是直接從「超級batch」中篩選出「子batch」。
技術介紹
用數(shù)學語言來描述這個問題,就是從大小為B的「超級batch」??中提取出與學習最相關的子batch ?={????,??∈[1,…,??]}???,過濾比率可以寫作??=1???/??。
之前的優(yōu)先采樣(prioritized sampling)會使用基于模型的評分函數(shù)對每個數(shù)據(jù)點打分,再按比例采樣。JEST則直接對整個子batch評分,再按照batch級別的分數(shù)采樣。
一種最直觀的啟發(fā)式方法就是在現(xiàn)有模型參數(shù) ?? : ??hard?(?|??)=??(?|??) 中,直接選擇損失值最高的batch,這種方法可被稱之為「硬學習」(hard learner)。
這種方法具有丟棄瑣碎數(shù)據(jù)的理想屬性,已被證明適用于小型、干凈的數(shù)據(jù)集;然而對于較大、較少管理的數(shù)據(jù)集往往弊大于利,因為它依舊會采樣到噪聲數(shù)據(jù)。
另一種方法常用于多模態(tài),使用具有參數(shù) ???:??^easy?(?|???)=???(?|???) 的參考模型為預訓練模型采樣數(shù)據(jù)。但作者依舊否定了這個方案,因為它無法直接反映模型當前的狀態(tài),可能過度依賴參考模型的選擇,而且不易于擴展。
最后,論文選擇借鑒ICML 2022年的一篇論文中提到的方法,將上述兩方面的評分結合起來:??^learn?(?|??,???)=??hard?(?|??)+??^easy?(?|???)=??(?|??)???(?|???),并將這種啟發(fā)式方法稱為「可學習性評分」(learnability score)。
其中,batch上的損失值??(?|??)是各數(shù)據(jù)點之和,使用sigmoid對比損失函數(shù)計算(sigmoid-contrastive loss),因為相比softmax對比損失而言,它的擴展性更強。
由于batch上的對比損失可以分解為每個樣本的條件損失之和,因此可學習性評分可被分解為單個樣本可學習性評分???(??|??,???,?)之和,寫作:
使用的順序采樣方法則受到了block Gibbs采樣的啟發(fā)。在第n次迭代、對第B_n個batch進行采樣時,依據(jù)如下概率公式對塊{X_k}進行無替換采樣:
將X_k塊添加到B_n中來更新當前采樣的batch,直至迭代數(shù)n=N時終止。算法的總體流程如下圖所示:
實驗中發(fā)現(xiàn),使用迭代數(shù)N=16且每次迭代時獨立采樣b/N=2048個樣本時,就足以恢復出學習性非常高的batch。
可學習性評分中涉及到使用參考模型為數(shù)據(jù)點打分,之前的方法慣常使用額外的小型模型,但這會增加每次迭代的計算成本,降低總體FLOP效率增益。
因此論文使用了在線模型近似的方法以及效率較高的FlexiViT架構,只使用降低分辨率的32×32的patch來評估「超級batch」,與全分辨率、patch大小為16×16的方法相比減少了72%的FLOP,以及67%的掛鐘時間(wall-clock time)。
此外,論文還提出了進行多分辨率訓練的技巧。將每個batch隨機分成兩半,使用不同分辨率編碼后再拼接起來,提升了評分過程和訓練的效率。
下圖詳細描述了全分辨率JEST和多分辨率Flexi-JEST方法的偽代碼實現(xiàn)。
所有JEST實驗都在WebLI數(shù)據(jù)集上運行,包含經(jīng)過寬松過濾的十億規(guī)模的英語圖像-文本對,參考模型的訓練則使用其中經(jīng)過高質(zhì)量過濾100M大小的子集(被稱為WebLI-curated)。
在WebLI的基礎上,作者還額外從網(wǎng)絡上抓取了6億個文本-圖像對并經(jīng)過同樣強度的過濾,組成WebLI-curated++數(shù)據(jù)集訓練參考模型,拓展出JEST++/FlexiJEST++方法,來探索對數(shù)據(jù)管理的擴展。
論文所報告的平均性能包括4個多模態(tài)規(guī)范基準:ImageNet 0-Shot和10-Shot 分類以及COCO圖像到文本和文本到圖像的top-1檢索。
實驗結果
圖1中可以看到,使用JEST或FlexiJEST方法的最明顯優(yōu)勢就是效率提升。
左圖中,相比原有的SigLIP基線模型,JEST++可以在訓練數(shù)據(jù)量減少13.1×的情況下達到相同準確率。即使考慮到額外引入的打分成本,也有近10×的FLOP效率提升(中圖)。
右圖展現(xiàn)了JEST++/FlexiJEST++(綠色)與先前方法(灰色)的比較,相比CLIP、EVA-CLIP經(jīng)典模型實現(xiàn)了計算成本和性能的雙重提升。
左圖和中圖的平均準確率由8個下游任務得出,右圖性能由ImageNet和COCO基準測試得出
產(chǎn)生可學習batch
研究人員首先評估了JEST在選擇可學習batch方面的效果。
為了直觀地理解這一方法,作者們先將可學習性矩陣進行可視化,即學習模型和參考模型之間,對batch中所有示例對的損失差異。
JEST就是按照示例子矩陣的可學習性總和比例進行采樣。
由于矩陣明顯非對角關系(圖2,左),獨立選擇顯然是次優(yōu)的。
經(jīng)過少量迭代(對應于用N=16個塊填充batch),作者發(fā)現(xiàn)子batch的可學習性快速增加,達到了需要數(shù)千次迭代的暴力吉布斯采樣(Gibbs sampling )所提取batch的可學習性(圖2,中)。
對于0.5、0.8和0.9的過濾比例,他們從大小分別為65,536、163,840和327,680的超級batch中選擇32,768個示例的子batch。
在圖2右側,研究者還發(fā)現(xiàn)子batch的可學習性隨著更大的過濾比例而增加。
總之,JEST算法是在訓練過程中選擇高度可學習batch的有效,且高效的方法。
加速多模態(tài)學習
接下來,研究人員使用JEST算法選擇的可學習batch,檢驗訓練模型的效果。
所有實驗都使用在WebLI-curated上訓練的參考模型,這是一個ViT-B/16和Bert-B圖像-文本雙編碼器,30億訓練樣本,采用sigmoid對比損失函數(shù)。
圖3(左)顯示了在訓練過程中多個下游任務(ImageNet 0-Shot/10-Shot準確率和COCO圖像到文本/文本到圖像檢索)的平均性能。
結果還發(fā)現(xiàn),JEST顯著加速了學習過程。
在使用50%、80%和90%的過濾比例時,分別只需20億、10億和6.7億訓練樣本就達到了30億均勻基準的最終性能。
在更大的過濾比例下,坐著觀察到類似于更大batch size時的訓練不穩(wěn)定性,需要修改Adam優(yōu)化器(β2 = 0.95)以穩(wěn)定訓練,這表明JEST的數(shù)據(jù)篩選可以被視為增加了有效batch size。
在最終性能方面,當過濾90%的數(shù)據(jù)時,JEST也帶來了高達6%的顯著提升(圖3,中間,藍色曲線)。
值得注意的是,這種scaling行為這種性能提升在獨立樣本選擇方法中,并沒有觀察到。(圖3,中間,橙色曲線)。
最后,研究者還評估JEST是否也改善了,除可學習性之外的其他優(yōu)先標準。
圖3右側顯示了使用easy-reference優(yōu)先選擇的模型在不同過濾比例下的性能。
與基于可學習性的優(yōu)先選擇一致,JEST仍優(yōu)于獨立樣本選擇,特別是在高過濾比例下(在這種情況下,獨立樣本選擇導致性能下降)。
優(yōu)先選擇具有最高損失的數(shù)據(jù)產(chǎn)生了較小的收益,并且隨著過濾更多數(shù)據(jù)而更快地退化(圖10)。
由于基于可學習性的JEST產(chǎn)生了最佳的scaling行為,研究人員在后續(xù)實驗中保留了這一標準。
多分辨率訓練和在線batch選擇之間的協(xié)同效應
隨著數(shù)據(jù)batch中被過濾的比例增加,基于可學習性評分的JEST變得更加高效。
然而,評分的成本會帶來顯著的提升:過濾超級batch 80%的數(shù)據(jù)會導致每次迭代的浮點運算量是IID訓練的4倍,或者在緩存參考模型得分時是2.3倍。
盡管JEST在訓練迭代次數(shù)方面(以下簡稱「訓練效率」)顯著提高了效率,但額外的評分浮點運算降低了其相對于IID基準的計算效率(圖1,左vs右)。
因此,作者還研究了一種計算效率更高的變體,稱為Flexi-JEST,它使用多分辨率訓練和低分辨率評分,將總開銷降低到僅比基準高10%(圖4,左)。
這些近似方法對性能有什么影響?
正如預期的那樣,F(xiàn)lexi-JEST的每次迭代性能相對于JEST有所下降,但仍然比IID有顯著的加速(圖1,左;圖4,中)。
然而,考慮到總浮點運算量的減少,每次迭代性能的下降是非常有利的:最好的Flexi-JEST模型與40B Siglip運行產(chǎn)生相同的平均性能,但浮點運算量減少了9.9倍,比全分辨率JEST少2倍(圖1,右;圖4,中)。
這些實驗表明了多分辨率訓練和聯(lián)合示例選擇之間的協(xié)同效應,前者為加速后者提供了高效和準確的評分能力。
實驗結果,還指出了數(shù)據(jù)策劃策略的帕累托前沿(pareto front)。
如果以計算為代價來最大化訓練速度或訓練效率,全分辨率JEST方法相對于可比的IID訓練運行,可以產(chǎn)生高達13倍的加速。
實現(xiàn)強大數(shù)據(jù)質(zhì)量引導
可學習性評分的核心是,一個在人類選擇的小型、精心篩選的數(shù)據(jù)集上,訓練的參考模型。
JEST的性能如何隨不同的篩選策略(在質(zhì)量和數(shù)量之間權衡)而變化?
此外,JEST訓練的改進是否與參考模型的性能相關,還是這些指標是分離的?
理解質(zhì)量與數(shù)量的權衡
研究人員探索了三種規(guī)模的數(shù)據(jù)篩選,每種都是原始WebLI數(shù)據(jù)集的一個子集:
- 弱篩選(十億級規(guī)模):使用圖像-文本對齊(ITA)過濾器。
- 中度篩選(3億級規(guī)模):使用ITA過濾器或文本質(zhì)量(TQ)過濾器。
- 強篩選(1億級規(guī)模):結合使用TQ、ITA和額外的圖像質(zhì)量(aesthetic)過濾器。
在整個過程中,作者將這個強篩選子集稱為「WebLI-curated」。
然后,他們在這四個WebLI子集上,各訓練10個epoch的標準SigLIP編碼器,并將它們用作在全WebLI數(shù)據(jù)集上進行JEST訓練的參考模型。
在不同的數(shù)據(jù)篩選方法中,參考模型的性能和JEST的性能似乎是解耦的(甚至可能是反相關的;圖5,左)。
雖然增加篩選(和減少數(shù)據(jù)集大小)會產(chǎn)生較弱的模型,但當它們被用作JEST預訓練的參考模型時,卻產(chǎn)生了相反的效果:
使用強篩選參考模型的JEST獲得了2.7%的改進,中度篩選獲得了1.5%的改進,弱篩選獲得了0.3%的改進。
擴展數(shù)據(jù)篩選
假設參考模型性能與JEST性能之間的普遍解耦,可能僅僅是由數(shù)據(jù)篩選所施加的數(shù)據(jù)集大小限制造成的。
為了理解這種效果,研究人員在WebLI-curated上訓練了5個參考模型,同時改變所見的總樣本數(shù)(從2.5億到30億)。
在這種情況下,圖5(右)顯示了改進的參考模型與更好的JEST預訓練之間存在著顯著的相關性。
這表明「解耦」現(xiàn)象主要可以歸因于參考模型因篩選后數(shù)據(jù)集大小減少而導致的飽和。
此外,研究人員還注意到,當數(shù)據(jù)集達到飽和時,圖5(右)中的相關性開始崩解,即在10個epoch或者看到10億個樣本之后。
這些結果表明,JEST可能會從進一步擴大參考數(shù)據(jù)集的數(shù)據(jù)篩選中獲益。
鑒于使用WebLI-curated++對數(shù)據(jù)進行擴展整理能顯著提高參考模型的性能,作者提出了是否有必要在原始WebLI數(shù)據(jù)集上進行預訓練的問題。
然而,在評估參考模型在不同數(shù)據(jù)集上的性能時,卻發(fā)現(xiàn):雖然它在2個下游任務上的性能優(yōu)于WebLI預訓練,但在其他6個任務上的性能,以及平均性能都明顯低于WebLI預訓練(表 5)。
與現(xiàn)有數(shù)據(jù)比較
最后,論文應用JEST++在公開的LAION-2B數(shù)據(jù)集上進行預訓練,刪除了其中不安全的圖像-文本對,但沒有進行其他的預先過濾。
這個數(shù)據(jù)規(guī)模相比的SOTA方法DBP減少了4×,但JEST++依舊遠遠超過了所有之前的離線數(shù)據(jù)管理方法。
簡化數(shù)據(jù)管理
之前提到過,用于預訓練的WebLI-curated是原始數(shù)據(jù)集WebLI過濾后得到的,以求篩選出高質(zhì)量的圖像-文本對齊的數(shù)據(jù)。
如表3所示,這種離線數(shù)據(jù)管理流程對IID(獨立同分布)訓練方法的性能至關重要,但JEST++則表現(xiàn)出了對預過濾流程的魯棒性。即使沒有過濾,JEST++的性能也沒有出現(xiàn)明顯下滑,降低了模型對基礎數(shù)據(jù)集的要求。
結論和局限性
總體來說,JEST方法展現(xiàn)出了「數(shù)據(jù)質(zhì)量引導」(data quality bootstrapping)方法的巨大潛力,即使用小規(guī)模精選數(shù)據(jù)集來指導對更大的、未經(jīng)管理的數(shù)據(jù)集的學習。
最近的研究表明,在下游任務未知時,靜態(tài)數(shù)據(jù)集的過濾會限制模型性能。這篇論文的結果則表明,相比單獨選擇樣本的方法,在線構建batch能提高預訓練的效率。
無論是使用JEST參考模型對數(shù)據(jù)集進行預評分,還是通過可學習性評分來根據(jù)模型需求進行動態(tài)調(diào)整,都可以成為通用基礎數(shù)據(jù)集的更有效率的替代方案。
論文的最后,作者也提出了該方法的局限性。雖然JEST同時實現(xiàn)了性能增益和訓練成本降低,但依舊依賴于小型、精心管理的參考數(shù)據(jù)集,它指定了未經(jīng)管理的更大數(shù)據(jù)集中優(yōu)先考慮的分布。
因此,未來的工作可以探索一種方法,從指定的下游任務中如何推斷出參考數(shù)據(jù)集的組成和分布。