參數量1/50,Meta發布110億參數模型,擊敗谷歌PaLM
我們可以將大型語言模型(LLMs)理解為小樣本學習者,其能夠通過很少的例子就能學習新任務,甚至僅通過簡單的說明就能學習,其中對模型參數量和訓練數據的大小進行擴展是模型擁有泛化能力的關鍵。LLMs 的這種提升歸功于更強大算力和存儲能力。直觀上,推理能力的提高會帶來更好的泛化,從而減少樣本的學習,然而目前還不清楚有效的小樣本學習在多大程度上需要大量的模型參數知識。?
目前為止檢索增強模型還沒有展示出令人信服的小樣本學習能力。論文中,來自 Meta AI Research 等機構的研究者提出小樣本學習是否需要模型在其參數中存儲大量信息,以及存儲是否可以與泛化解耦。他們提出 Atlas,其是檢索增強語言模型的一種,擁有很強的小樣本學習能力,即使參數量低于目前其它強大的小樣本學習模型。
模型采用非參數存儲,即使用基于大型外部非靜態知識源上的神經檢索器去增強參數語言模型。除了存儲能力,此類架構在適應性、可解釋性和效率方面都存在優勢,因此很有吸引力。
論文地址:https://arxiv.org/pdf/2208.03299.pdf??
Atlas 檢索相關文檔是基于 Contriever 雙編碼器架構的通用密度檢索器,檢索文件時基于當前上下文檢索相關文件。檢索到的文檔與當前上下文一起交由序列到序列模型處理,該模型使用 Fusion-in-Decoder 架構生成相應的輸出。
作者研究了不同技術對訓練 Atlas 在一系列下游任務(包括問答和事實檢查)上的小樣本數據集性能的影響。研究發現聯合預訓練組件對于小樣本性能至關重要,作者評估了許多現有和新穎的預訓練任務和方案,Atlas 在小樣本和資源豐富的環境中都擁有強大的下游性能。
在只有 11B 個參數的情況下,Atlas 使用 64 個訓練示例在 NaturalQuestions(NQ)上實現了 42.4% 準確率,比 540B 參數模型 PaLM( 39.6% ) 高出近 3 個百分點,在全數據集設置中(Full)達到 64.0% 準確率。
?
Yann LeCun 表示:Atlas 是一個不太大的語言模型(11B 參數),在問答和事實核查方面擊敗了「大家伙」。Atlas 主要區別在于它可以從語料庫中檢索事實。
方法概覽?
Atlas 遵循文本到文本框架。這意味著所有任務的總體框架是:系統以文本查詢作為輸入,生成文本輸出。例如,在問答任務情況下,查詢對應于問題,模型需要生成答案。在分類任務情況下,查詢對應于文本輸入,模型生成類標簽,即標簽對應的詞。圖 2 中的 KILT 基準給出了更多下游任務的示例。許多自然語言處理任務需要知識,Atlas 的目標是通過檢索增強標準文本到文本模型,因為檢索可能對于模型小樣本場景下的學習能力至關重要。
架構
Atlas 模型基于兩個子模型:檢索器和語言模型。當執行任務時,從問答到生成 Wikipedia 文章,模型首先通過檢索器從大型文本語料庫中檢索前 k 個相關文檔。然后,這些文檔連同查詢一起作為輸入給到語言模型,生成輸出。檢索器和語言模型都基于預訓練的 transformer 網絡,下面對它們做詳細介紹。?
檢索器:Atlas 的檢索器模塊基于 Contriever,這是一種基于連續密度嵌入的信息檢索技術。Contriever 使用雙編碼器架構,其中查詢和文檔由 transformer 編碼器獨立嵌入。平均池化應用于最后一層的輸出,以獲得每個查詢或文檔的向量表示。然后通過計算查詢和每個文檔間的相互嵌入的點積,得到它們的相似度分數。Contriever 模型使用 MoCo 對比損失進行預訓練,并且僅使用無監督數據。密度檢索器的優點之一是查詢和文檔編碼器都可以在沒有文檔注釋的情況下使用標準技術(例如梯度下降和蒸餾)進行訓練。?
語言模型:對于語言模型,Atlas 依賴于 T5 序列到序列架構。模型同時也依賴于對序列到序列模型的 Fusion-in-Decoder 修改,并在編碼器中獨立處理每個文檔。之后模型連接對應于不同文檔的編碼器的輸出,并在解碼器中對單個序列執行 cross-attention。模型把查詢連接到編碼器中的每個文檔。在語言模型中處理檢索到的文檔的另一種方法是將查詢和所有文檔連接起來,并使用這個長序列作為模型的輸入。但這種方法可擴展性較差,即它不會隨著文檔的數量增多而擴展,因為編碼器中的自注意力機制會導致 O(n^2)的時間復雜度(這里 n 是文檔數量)。
實驗結果?
作者在 NaturalQuestions 和 TriviaQA 這兩個開放域問答基準上評估 Atlas。并且分別使用 64 個樣例的小樣本數據集和完整的訓練集,與之前的工作進行比較,詳細對比見下表。
NaturalQuestions 和 TriviaQA 的 64-shot 問答中表現最優。特別是它優于更大的模型 (PaLM) 或需要更多訓練計算的模型(Chinchilla)。在使用全量的訓練集時,Atlas 也能到最優結果,例如把 NaturalQuestions 的準確率從 55.9% 提高到 60.4%。這個結果是在 Atlas 的默認設置下,使用由 CCNet 和 2021 年 12 月 Wikipedia 語料庫組成的索引獲得的。 下表展示了在事實核查數據集 FEVER 上的測試結果。?
Atlas 在 64-shot 情況下,訓練樣例采樣自全量訓練集。Atlas 達到了 64.3% 的準確率。而在 15-shot 的情況下,從每個類中統一采樣 5 個樣例,與 Gopher 結果比較,Atlas 準確率為 56.2%,比 Gopher 高 5.1 個百分點。在全量訓練集上微調 Atlas 模型,達到 78% 的準確率,比 ProoFVer 低 1.5%。ProoFVer 使用專門的架構,用句子級注釋訓練的檢索器,并由維基百科語料庫提供與 FEVER 一起發布,而 Atlas 從 CCNet 和 2021 年 12 月的維基百科轉儲中檢索。當給 Atlas 由 FEVER Wikipedia 語料庫組成的索引,Atlas 取得了 80.1% 最優水平。
為驗證 Atlas 的性能,Atlas 在 KILT 進行了評估,KILT 是由幾個不同的知識密集型任務組成的基準。下表展示了測試集的結果。?
Atlas 64-shot 在實驗中遠遠超過隨機算法,甚至與排行榜上的某些經過微調的模型不相上下。如在 FEVER 上,Atlas 64-shot 僅落后 Sphere、SEAL 和 Re2G 2-2.5 分,而在 zero-shot RE 上的表現優于 Sphere 和 SEAL。在全量數據集上,Atlas 在 3 個數據集的表現與最好的模型相差在 3% 以內,但在其余 5 個數據集中是表現最好的。