預(yù)訓(xùn)練無(wú)需注意力,擴(kuò)展到4096個(gè)token不成問(wèn)題,與BERT相當(dāng)
Transformer 作為 NLP 預(yù)訓(xùn)練模型架構(gòu),能夠有效的在大型未標(biāo)記的數(shù)據(jù)上進(jìn)行學(xué)習(xí),研究已經(jīng)證明,Transformer 是自 BERT 以來(lái) NLP 任務(wù)的核心架構(gòu)。
最近的工作表明,狀態(tài)空間模型(SSM)是長(zhǎng)范圍序列建模有利的競(jìng)爭(zhēng)架構(gòu)。SSM 在語(yǔ)音生成和 Long Range Arena 基準(zhǔn)上取得了 SOTA 成果,甚至優(yōu)于 Transformer 架構(gòu)。除了提高準(zhǔn)確率之外,基于 SSM 的 routing 層也不會(huì)隨著序列長(zhǎng)度的增長(zhǎng)而呈現(xiàn)二次復(fù)雜性。
本文中,來(lái)自康奈爾大學(xué)、 DeepMind 等機(jī)構(gòu)的研究者提出了雙向門(mén)控 SSM (BiGS),用于無(wú)需注意力的預(yù)訓(xùn)練,其主要是將 SSM routing 與基于乘法門(mén)控(multiplicative gating)的架構(gòu)相結(jié)合。該研究發(fā)現(xiàn) SSM 本身在 NLP 的預(yù)訓(xùn)練中表現(xiàn)不佳,但集成到乘法門(mén)控架構(gòu)中后,下游準(zhǔn)確率便會(huì)提高。
實(shí)驗(yàn)表明,在受控設(shè)置下對(duì)相同數(shù)據(jù)進(jìn)行訓(xùn)練,BiGS 能夠與 BERT 模型的性能相匹配。通過(guò)在更長(zhǎng)的實(shí)例上進(jìn)行額外預(yù)訓(xùn)練,在將輸入序列擴(kuò)展到 4096 時(shí),模型還能保持線性時(shí)間。分析表明,乘法門(mén)控是必要的,它修復(fù)了 SSM 模型在變長(zhǎng)文本輸入上的一些特定問(wèn)題。
論文地址:https://arxiv.org/pdf/2212.10544.pdf
方法介紹
SSM 通過(guò)以下微分方程將連續(xù)輸入 u (t) 與輸出 y (t) 聯(lián)系起來(lái):
對(duì)于離散序列,SSM 參數(shù)被離散化,其過(guò)程可以近似為:
這個(gè)方程可以解釋為一個(gè)線性 RNN,其中 x_k 是一個(gè)隱藏狀態(tài)。y 也可以用卷積計(jì)算:
Gu 等人展示了一種在神經(jīng)網(wǎng)絡(luò)中使用 SSM 的有效方法,他們開(kāi)發(fā)了參數(shù)化 A 的方法,稱為 HiPPO,其產(chǎn)生了一個(gè)穩(wěn)定而高效的架構(gòu),稱為 S4。這保留了 SSM 對(duì)長(zhǎng)期序列建模的能力,同時(shí)比 RNN 訓(xùn)練更有效。最近,研究人員提出了 S4 的簡(jiǎn)化對(duì)角化版本,它通過(guò)對(duì)原始參數(shù)更簡(jiǎn)單的近似實(shí)現(xiàn)了類似的結(jié)果。在高層次上,基于 SSM 的 routing 為神經(jīng)網(wǎng)絡(luò)中的序列建模提供了一種替代方法,而無(wú)需二次計(jì)算的注意力成本。
預(yù)訓(xùn)練模型架構(gòu)
SSM 能取代預(yù)訓(xùn)練中的注意力嗎?為了回答這個(gè)問(wèn)題,該研究考慮了兩種不同的架構(gòu),如圖 1 所示的堆疊架構(gòu)(STACK)和乘法門(mén)控架構(gòu)(GATED)。
具有自注意力的堆疊架構(gòu)相當(dāng)于 BERT /transformer 模型,門(mén)控架構(gòu)是門(mén)控單元的雙向改編,最近也被用于單向 SSM。帶有乘法門(mén)控的 2 個(gè)序列塊(即前向和后向 SSM)夾在前饋層中。為了進(jìn)行公平比較,門(mén)控架構(gòu)的大小保持與堆疊架構(gòu)相當(dāng)。
圖 1:模型變量。STACK 是標(biāo)準(zhǔn) transformer 架構(gòu),GATED 為基于門(mén)控單元。對(duì)于 Routing 組件(虛線),該研究同時(shí)考慮雙向 SSM(如圖所示)和標(biāo)準(zhǔn)自注意力。門(mén)控(X)表示逐元素乘法。
實(shí)驗(yàn)結(jié)果
預(yù)訓(xùn)練
表 1 顯示了 GLUE 基準(zhǔn)測(cè)試中不同預(yù)訓(xùn)練模型的主要結(jié)果。BiGS 在 token 擴(kuò)展上復(fù)制了 BERT 的準(zhǔn)確率。這一結(jié)果表明,在這樣的計(jì)算預(yù)算下,SSM 可以復(fù)制預(yù)訓(xùn)練 transformer 模型的準(zhǔn)確率。這些結(jié)果明顯優(yōu)于其他基于非注意力的預(yù)訓(xùn)練模型。想要達(dá)到這個(gè)準(zhǔn)確率,乘法門(mén)控是必要的。在沒(méi)有門(mén)控的情況下,堆疊 SSM 的結(jié)果明顯更差。為了檢查這種優(yōu)勢(shì)是否主要來(lái)自于門(mén)控的使用,本文使用 GATE 架構(gòu)訓(xùn)練了一個(gè)基于注意力的模型;然而,結(jié)果顯示該模型的效果實(shí)際上低于 BERT。
表 1:GLUE 結(jié)果。(Top)在控制設(shè)置下,不同架構(gòu)和 routing 的比較。參見(jiàn)圖 2 了解詳細(xì)信息。(Bottom) 報(bào)告了基于 CNN、LSTM 和 FNet 的其他非注意力預(yù)訓(xùn)練模型的可比結(jié)果。
Long-Form 任務(wù)
表 2 結(jié)果顯示,可以將 SSM 與 Longformer EncoderDecoder (LED) 和 BART 進(jìn)行比較,但是,結(jié)果顯示它在遠(yuǎn)程任務(wù)中表現(xiàn)得也不錯(cuò),甚至更勝一籌。與其他兩種方法相比,SSM 的預(yù)訓(xùn)練數(shù)據(jù)要少得多。即使 SSM 不需要在這些長(zhǎng)度上進(jìn)行近似,長(zhǎng)格式也依舊很重要。
表 2:SCROLLS Encoder 測(cè)試結(jié)果?;€模型都是編碼器 —— 解碼器模型,一個(gè)基于 Longformer (LED),另一個(gè)基于 BART。輸入的長(zhǎng)度有截?cái)唷?/em>
更多內(nèi)容請(qǐng)查看原論文。