Transformer后繼有模!MSRA提出全新大模型基礎(chǔ)架構(gòu):推理速度8倍提升,內(nèi)存占用減少70%
本文經(jīng)AI新媒體量子位(公眾號(hào)ID:QbitAI)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請(qǐng)聯(lián)系出處。
微軟大模型新架構(gòu),正式向Transformer發(fā)起挑戰(zhàn)!
論文標(biāo)題明晃晃地寫道:
Retentive Network(RetNet):大模型領(lǐng)域Transformer的繼任者。
圖片
論文提出新的Retention機(jī)制來代替Attention。來自微軟亞研院和清華的研究人員,毫不諱言“野心”,大膽放話:
RetNet實(shí)現(xiàn)了良好的擴(kuò)展結(jié)果、并行訓(xùn)練、低成本部署和高效推理。
這些特性使這一基礎(chǔ)架構(gòu),成為大語言模型中Transformer的有力繼承者。
而實(shí)驗(yàn)數(shù)據(jù)也顯示,在語言建模任務(wù)上:
- RetNet可以達(dá)到與Transformer相當(dāng)?shù)睦Щ蠖龋╬erplexity)
- 推理速度達(dá)8.4倍
- 內(nèi)存占用減少70%
- 具有良好的擴(kuò)展性
并且當(dāng)模型大小大于一定規(guī)模時(shí),RetNet表現(xiàn)會(huì)優(yōu)于Transformer。
圖片
Transformer果真“后繼有模”了?具體詳情,一起來看。
解決“不可能三角”
Transformer在大語言模型中的重要性毋庸置疑。無論是OpenAI的GPT系列,還是谷歌的PaLM、Meta的LLaMA,都是基于Transformer打造。
但Transformer也并非完美無缺:其并行處理機(jī)制是以低效推理為代價(jià)的,每個(gè)步驟的復(fù)雜度為O(N);Transformer是內(nèi)存密集型模型,序列越長(zhǎng),占用的內(nèi)存越多。
在此之前,大家也不是沒想過繼續(xù)改進(jìn)Transformer。但主要的幾種研究方向都有些顧此失彼:
線性attention可以降低推理成本,但性能較差;
循環(huán)神經(jīng)網(wǎng)絡(luò)則無法進(jìn)行并行訓(xùn)練。
也就是說,這些神經(jīng)網(wǎng)絡(luò)架構(gòu)面前擺著一個(gè)“不可能三角”,三個(gè)角代表的分別是:并行訓(xùn)練、低成本推理和良好的擴(kuò)展性能。
圖片
RetNet的研究人員想做的,就是化不可能為可能。
具體而言,RetNet在Transformer的基礎(chǔ)上,使用多尺度保持(retention)機(jī)制替代了標(biāo)準(zhǔn)的自注意力機(jī)制。
與標(biāo)準(zhǔn)自注意力機(jī)制相比,保持機(jī)制有幾大特點(diǎn):
引入位置相關(guān)的指數(shù)衰減項(xiàng)取代softmax,簡(jiǎn)化了計(jì)算,同時(shí)使前步的信息以衰減的形式保留下來。
引入復(fù)數(shù)空間表達(dá)位置信息,取代絕對(duì)或相對(duì)位置編碼,容易轉(zhuǎn)換為遞歸形式。
另外,保持機(jī)制使用多尺度的衰減率,增加了模型的表達(dá)能力,并利用GroupNorm的縮放不變性來提高retention層的數(shù)值精度。
圖片
△RetNet的雙重表示
每個(gè)RetNet塊包含兩個(gè)模塊:多尺度保持(MSR)模塊和前饋網(wǎng)絡(luò)(FFN)模塊。
保持機(jī)制支持以三種形式表示序列:
- 并行
- 遞歸
- 分塊遞歸,即并行表示和遞歸表示的混合形式,將輸入序列劃分為塊,在塊內(nèi)按照并行表示進(jìn)行計(jì)算,在塊間遵循遞歸表示。
其中,并行表示使RetNet可以像Transformer一樣高效地利用GPU進(jìn)行并行訓(xùn)練。
遞歸表示實(shí)現(xiàn)了O(1)的推理復(fù)雜度,降低了內(nèi)存占用和延遲。
分塊遞歸則可以更高效地處理長(zhǎng)序列。
這樣一來,RetNet就使得“不可能三角”成為可能。以下為RetNet與其他基礎(chǔ)架構(gòu)的對(duì)比結(jié)果:
在語言建模任務(wù)上的實(shí)驗(yàn)結(jié)果,進(jìn)一步證明了RetNet的有效性。
結(jié)果顯示,RetNet可以達(dá)到與Transformer相似的困惑度(PPL,評(píng)價(jià)語言模型好壞的指標(biāo),越小越好)。
同時(shí),在模型參數(shù)為70億、輸入序列長(zhǎng)度為8k的情況下,RetNet的推理速度能達(dá)到Transformer的8.4倍,內(nèi)存占用減少70%。
在訓(xùn)練過程中,RetNet在內(nèi)存節(jié)省和加速效果方面,也比標(biāo)準(zhǔn)Transformer+FlashAttention表現(xiàn)更好,分別達(dá)到25-50%和7倍。
值得一提的是,RetNet的推理成本與序列長(zhǎng)度無關(guān),推理延遲對(duì)批量大小不敏感,允許高吞吐量。
圖片
另外,當(dāng)模型參數(shù)規(guī)模大于20億時(shí),RetNet的表現(xiàn)會(huì)優(yōu)于Transformer。
研究團(tuán)隊(duì)
RetNet的研究團(tuán)隊(duì),來自微軟亞研院和清華大學(xué)。
共同一作為孫宇濤和董力。
孫宇濤,清華大學(xué)計(jì)算機(jī)系本科,現(xiàn)在在微軟亞研院實(shí)習(xí)。
董力,微軟亞研院研究員。他也是此前引發(fā)大量關(guān)注的“能記住10億token的Transformer”的論文作者之一。
RetNet論文的通訊作者是韋福如。他是微軟亞洲研究院全球研究合伙人,10億token Transformer亦是來自他的研究團(tuán)隊(duì)。
論文地址:https://arxiv.org/abs/2307.08621