TabR:檢索增強能否讓深度學習在表格數據上超過梯度增強模型?
這是一篇7月新發布的論文,他提出了使用自然語言處理的檢索增強Retrieval Augmented技術,目的是讓深度學習在表格數據上超過梯度增強模型。
檢索增強一直是NLP中研究的一個方向,但是引入了檢索增強的表格深度學習模型在當前實現與非基于檢索的模型相比幾乎沒有改進。所以論文作者提出了一個新的TabR模型,模型通過增加一個類似注意力的檢索組件來改進現有模型。據說,這種注意力機制的細節可以顯著提高表格數據任務的性能。TabR模型在表格數據上的平均性能優于其他DL模型,在幾個數據集上設置了新的標準,在某些情況下甚至超過了GBDT模型,特別是在通常被視為GBDT友好的數據集上。
TabR
表格數據集通常被表示為特征和標簽對{(xi, yi)},其中xi和yi分別是第i個對象的特征和標簽。一般有三種類型的主要任務:二元分類、多類分類和回歸。
對于表格數據我們會將數據集分為訓練部分、驗證部分和測試部分,模型對“輸入”或“目標”對象進行預測。當使用檢索技術時,檢索是在一組“上下文候選”或“候選”中完成的,被檢索的對象稱為“上下文對象”或簡稱為“上下文”。同一組候選對象用于所有輸入對象。
論文的實驗設置涉及調優和評估協議,其中需要超參數調優和基于驗證集性能的早期停止。然后在15個隨機種子的平均測試集上測試最佳超參數,并在算法比較中考慮標準偏差。
論文作者的目標是將檢索功能集成到傳統的前饋網絡中。該過程包括通過編碼器傳遞目標對象及其上下文候選者,然后檢索組件會對目標對象進行的表示,最后預測器進行預測。
編碼器和預測器模塊很簡單簡單,因為它們不是工作的重點。檢索模塊對目標對象的表示以及候選對象的表示和標簽進行操作。這個模塊可以看作是注意力機制的一般化版本。
這個過程包括幾個步驟:
- 如果編碼器包含至少一個塊,則將表示進行規范化;
- 根據與目標對象的相似性定義上下文對象;
- 基于softmax函數對上下文對象的相似性分配權重;
- 定義上下文對象的值;
- 使用值和權重輸出加權聚合。
上下文大小設置為一個較大的值96,softmax函數會自動選擇有效的上下文大小。
檢索模塊是最重要的部分
作者探討了檢索模塊的不同實現,特別是相似度模塊和值模塊。并且說明了是通過一下幾個步驟得到最終的模型。
1、作者評估了傳統注意力的相似性和值模塊,發現該配置與多層感知器(MLP)相似,因此不能證明使用檢索組件是合理的。
2、然后他們將上下文標簽添加到值模塊中,但發現這并沒有改進,這表明傳統注意力的相似性模塊可能是瓶頸。
3、為了改進相似度模塊,作者刪除了查詢的概念,并用L2距離替換點積。這種調整使得幾個數據集上性能的顯著躍升。
4、值模塊也進行改進,靈感來自最近提出的DNNR(用于回歸問題的kNN算法的廣義版本)。新的值模塊帶來了進一步的性能改進。
5、最后,作者創建模型TabR。在相似性模塊中省略縮放項,不包括目標對象在其自身的上下文中(使用交叉注意),平均而言會得到更好的結果。
生成的TabR模型為基于檢索的表格深度學習問題提供了一種健壯的方法。
作者也強調了TabR模型的兩個主要局限性:
與所有檢索增強模型一樣,從應用程序的角度來看,使用真實的訓練對象進行預測可能會帶來一些問題,例如隱私和道德問題。
TabR的檢索組件雖然比以前的工作更有效,但會產生明顯的開銷。所以它可能無法有效地擴展以處理真正的大型數據集。
實驗結果
作者將TabR與現有的檢索增強解決方案和最先進的參數模型進行比較。除了完全配置的TabR,他們還使用了一個簡化版本,TabR- s,它不使用特征嵌入,只有一個線性編碼器和一個塊預測器。
與全參數深度學習模型的比較表明,TabR在幾個數據集上優于大多數模型,除了MI數據集,在其他數據集也很有競爭力。在許多數據集上,它比多層感知器(MLP)提供了顯著的提升。
與GBDT模型相比,調整后的TabR在幾個數據集上也有明顯的改進,并且在其他數據集上保持競爭力(除了MI數據集),并且TabR的平均表現也優于GBDT模型。
總之,TabR將自己確立為表格數據問題的強大深度學習解決方案,展示了強大的平均性能,并在幾個數據集上設置了新的基準。它的基于檢索的方法具有良好的潛力,并且在某些數據集上可以明顯優于梯度增強的決策樹。
一些研究
1、凍結上下文以更快地訓練TabR
在TabR的原始實現中,由于需要對所有候選對象進行編碼并計算每個訓練批次的相似度,因此在大型數據集上的訓練可能很慢。作者提到在完整的“Weather prediction”數據集上訓練一個TabR需要18個多小時,該數據集有300多萬個對象。
作者注意到在訓練過程中,平均訓練對象的上下文(即,根據相似度模塊S,前m個候選對象及其分布)趨于穩定,這為優化提供了機會。在一定數量的epoch之后,他們提出了一個“上下文凍結”,即最后一次計算所有訓練對象的最新上下文,然后在其余的訓練中重用。
這種簡單的技術可以加速TabR的訓練,并且不會在指標上造成重大損失。在上面提到的完整的“Weather prediction”數據集上,它使速度提高了近7倍(將訓練時間從18小時9分鐘減少到3小時15分鐘),同時仍然保持有競爭力的均方根誤差(RMSE)值。
2、用新的訓練數據更新TabR不需要再訓練(初步探索)
在現實世界的場景中,在機器學習模型已經訓練完之后,通常會收到新的、看不見的訓練數據。作者測試了TabR在不需要再訓練的情況下合并新數據的能力,方法是將新數據添加到候選檢索集中。
他們使用完整的“Weather prediction”數據集進行了這個測試。結果表明在線更新可以有效地將新數據整合到訓練好的TabR模型中。這種方法可以通過在數據子集上訓練模型并從完整數據集中檢索模型來將TabR擴展到更大的數據集。
3、使用檢索組件增強XGBoost
作者試圖通過結合類似于TabR中的檢索組件來提高XGBoost的性能。這種方法涉及在原始特征空間中找到與給定輸入對象最接近的96個訓練對象(匹配TabR的上下文大小)。然后對這些最近鄰的特征和標簽進行平均,將標簽按原樣用于回歸任務,并將其轉換為用于分類任務的單一編碼。
將這些平均數據與目標對象的特征和標簽連接起來,形成XGBoost的新輸入向量。但是該策略并沒有顯著提高XGBoost的性能。試圖改變鄰居的數量也沒有產生任何顯著的改善。
總結
深度學習模型在表格類數據上一直沒有超越梯度增強模型,TabR還在這個方向繼續努力。