?如今,轉換器(Transformers)成為大多數先進的自然語言處理(NLP)和計算機視覺(CV)體系結構中的關鍵模塊。然而,表格式數據領域仍然主要以梯度提升決策樹(GBDT)算法為主導。于是,有人試圖彌合這一差距。其中,第一篇基于轉換器的表格數據建模論文是由Huang等人于2020年發表的論文《TabTransformer:使用上下文嵌入的表格數據建模》。
本文旨在提供該論文內容的基本展示,同時將深入探討TabTransformer模型的實現細節,并向您展示如何針對我們自己的數據來具體使用TabTransformer。
一、論文概述
上述論文的主要思想是,如果使用轉換器將常規的分類嵌入轉換為上下文嵌入,那么,常規的多層感知器(MLP)的性能將會得到顯著提高。接下來,讓我們更為深入地理解這一描述。
1.分類嵌入(Categorical Embeddings)
在深度學習模型中,使用分類特征的經典方法是訓練其嵌入性。這意味著,每個類別值都有一個唯一的密集型向量表示,并且可以傳遞給下一層。例如,由下圖您可以看到,每個分類特征都使用一個四維數組表示。然后,這些嵌入與數字特征串聯,并用作MLP的輸入。
帶有分類嵌入的MLP
2.上下文嵌入(Contextual Embeddings)
論文作者認為,分類嵌入缺乏上下文含義,即它們并沒有對分類變量之間的任何交互和關系信息進行編碼。為了將嵌入內容更加具體化,有人建議使用NLP領域當前所使用的轉換器來實現這一目的。
TabTransformer轉換器中的上下文嵌入
為了以可視化方式形象地展示上述想法,我們不妨考慮下面這個訓練后得到的上下文嵌入圖像。其中,突出顯示了兩個分類特征:關系(黑色)和婚姻狀況(藍色)。這些特征是相關的;所以,“已婚(Married)”、“丈夫(Husband)”和“妻子(Wife)”的值應該在向量空間中彼此接近,即使它們來自不同的變量。
經訓練后的TabTransformer轉換器嵌入結果示例
通過上圖中經過訓練的上下文嵌入結果,我們可以看到,“已婚(Married)”的婚姻狀況更接近“丈夫(Husband)”和“妻子(Wife)”的關系水平,而“未結婚(non-married)”的分類值則來自右側的單獨數據簇。這種類型的上下文使這樣的嵌入更加有用,而使用簡單形式的類別嵌入技術是不可能實現這種效果的。
3.TabTransformer架構
為了達到上述目的,論文作者提出了以下架構:
TabTransformer轉換器架構示意圖
(摘取自Huang等人2020年發表的論文)
我們可以將此體系結構分解為5個步驟:
- 標準化數字特征并向前傳遞
- 嵌入分類特征
- 嵌入經過N次轉換器塊處理,以便獲得上下文嵌入
- 把上下文分類嵌入與數字特征進行串聯
- 通過MLP進行串聯獲得所需的預測
雖然模型架構非常簡單,但論文作者表示,添加轉換器層可以顯著提高計算性能。當然,所有的“魔術”發生在這些轉換器塊內部;所以,接下來讓我們更加詳細地研究一下其中的實現過程。
4.轉換器
轉換器(Transformer)架構示意
(選自Vaswani等人于2017年發表的論文)
您可能以前見過轉換器架構,但為了快速介紹起見,請記住該轉換器是由編碼器和解碼器兩部分組成(見上圖)。對于TabTransformer,我們只關心將輸入的嵌入內容上下文化的編碼器部分(解碼器部分將這些嵌入內容轉換為最終輸出結果)。但它到底是如何做到的呢?答案是——多頭注意力機制。
5.多頭注意力機制(Multi-head-attention)
引用我最喜歡的關于注意力機制的文章的描述,是這樣的:
“自我關注(self attention)背后的關鍵概念是,這種機制允許神經網絡學習如何在輸入序列的各個片段之間以最好的路由方案進行信息調度。”
換句話說,自我關注(self-attention)有助于模型找出在表示某個單詞/類別時,輸入的哪些部分更重要,哪些部分相對不重要。為此,我強烈建議您閱讀一下上面引用的這篇文章,以便對自我關注為什么如此有效有一個更為直觀的理解。
多頭注意力機制
(選自Vaswani等人于2017年發表的論文)
注意力是通過3個學習過的矩陣來計算的——Q、K和V,它們代表查詢(Query)、鍵(Key)和值(Value)。首先,我們將矩陣Q和K相乘得到注意力矩陣。該矩陣被縮放并通過softmax層傳遞。然后,我們將其乘以V矩陣,得出最終值。為了更直觀地理解起見,請考慮下面的示意圖,它顯示了我們如何使用矩陣Q、K和V實現從輸入嵌入轉換到上下文嵌入。
自我關注流程可視化
通過重復該過程h次(使用不同的Q、K、V矩陣),我們就能夠得到多個上下文嵌入,它們形成我們最終的多頭注意力。
6.簡短回顧
讓我們總結一下上面所介紹的內容:
- 簡單的分類嵌入不包含上下文信息
- 通過轉換器編碼器傳遞分類嵌入,我們就能夠將嵌入上下文化
- 轉換器部分能夠將嵌入上下文化,因為它使用了多頭注意力機制
- 多頭注意力機制在編碼變量時使用矩陣Q、K和V來尋找有用的相互作用和相關性信息
- 在TabTransformer中,被上下文化的嵌入與數字輸入相串聯,并通過一個簡單的MLP輸出預測
雖然TabTransformer背后的想法很簡單,但您可能需要一些時間才能掌握注意力機制。因此,我強烈建議您重新閱讀以上解釋。如果您感到有些迷茫,請認真閱讀本文中所有建議的鏈接相關內容。我保證,做到這些后,您就不難搞明白注意力機制的原理了。
7.試驗結果展示
結果數據(選自Huang等人2020年發表的論文)
根據報告的結果,TabTransformer轉換器優于所有其他深度學習表格模型,此外,它接近GBDT的性能水平,這非常令人鼓舞。該模型對缺失數據和噪聲數據也相對穩健,并且在半監督環境下優于其他模型。然而,這些數據集顯然不是詳盡無遺的,正如以后發表的一些相關論文所證實的那樣,仍有很大的改進空間。
二、構建我們自己的示例程序
現在,讓我們最終來確定一下如何將模型應用于我們自己的數據。接下來的示例數據取自著名的Tabular Playground Kaggle比賽。為了方便使用TabTransformer轉換器,我創建了一個tabtransformertf包。它可以使用如下pip命令進行安裝:
并允許我們使用該模型,而無需進行大量預處理。
1.數據預處理
第一步是設置適當的數據類型,并將我們的訓練和驗證數據轉換為TF數據集。其中,前面安裝的軟件包中就提供了一個很好的實用程序可以做到這一點。
下一步是為分類數據準備預處理層。該分類數據稍后將被傳遞給我們的主模型。
這就是預處理!現在,我們可以開始構建模型了。
2.構建TabTransformer模型
初始化模型很容易。其中,有幾個參數需要指定,但最重要的幾個參數是:embeding_dim、depth和heads。所有參數都是在超參數調整后選擇的。
模型初始化后,我們可以像任何其他Keras模型一樣安裝它。訓練參數也可以調整,所以可以隨意調整學習速度和提前停止。
3.評價
競賽中最關鍵的指標是ROC AUC。因此,讓我們將其與PR AUC指標一起輸出來評估一下模型的性能。
您也可以自己給測試集評分,然后將結果值提交給Kaggle官方。我現在選擇的這個解決方案使我躋身前35%,這并不壞,但也不太好。那么,為什么TabTransfromer在上述方案中表現不佳呢?可能有以下幾個原因:
- 數據集太小,而深度學習模型以需要大量數據著稱
- TabTransformer很容易在表格式數據示例領域出現過擬合
- 沒有足夠的分類特征使模型有用
三、結論
本文探討了TabTransformer背后的主要思想,并展示了如何使用Tabtransformertf包來具體應用此轉換器。
歸納起來看,TabTransformer的確是一種有趣的體系結構,它在當時的表現明顯優于大多數深度表格模型。它的主要優點是將分類嵌入語境化,從而增強其表達能力。它使用在分類特征上的多頭注意力機制來實現這一點,而這是在表格數據領域使用轉換器的第一個應用實例。
TabTransformer體系結構的一個明顯缺點是,數字特征被簡單地傳遞到最終的MLP層。因此,它們沒有語境化,它們的價值也沒有在分類嵌入中得到解釋。在下一篇文章中,我將探討如何修復此缺陷并進一步提高性能。
譯者介紹
朱先忠,51CTO社區編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。
原文鏈接:https://towardsdatascience.com/transformers-for-tabular-data-tabtransformer-deep-dive-5fb2438da820?source=collection_home---------4----------------------------