Transformers學習上下文強化學習的時間差分方法 原創
上下文學習指的是模型在推斷時學習能力,而不需要調整其參數。模型(例如transformers)的輸入包括上下文(即實例-標簽對)和查詢實例(即提示)。然后,模型能夠根據上下文在推斷期間為查詢實例輸出一個標簽。上下文學習的一個可能解釋是,(線性)transformers的前向傳播在上下文中實現了對實例-標簽對的梯度下降迭代。在本文中,研究人員通過構造證明了transformers在前向傳播中也能實現時間差異(TD)學習,并將這一現象稱為上下文TD。在訓練transformers使用多任務TD算法后展示了上下文TD的出現,并進行了理論分析。此外,研究人員證明了transformers具有足夠的表達能力,可以在前向傳播中實現許多其他策略評估算法,包括殘差梯度、帶有資格跟蹤的TD和平均獎勵TD。
上下文學習已經成為大型語言模型最顯著的能力之一。在上下文學習中,模型的輸入(即提示)包括上下文(即實例-標簽對)和一個查詢實例。然后,模型在推斷期間(即前向傳播)為查詢實例輸出一個標簽。模型輸入和輸出的一個示例可以是:
其中,“5 → number; a → letter”是包含兩個實例-標簽對的上下文,“6”是查詢實例。根據上下文,模型推斷查詢“6”的標簽為“number”。值得注意的是,整個過程在模型的推斷時間內完成,而不需要調整模型的參數。
在(1)中的示例說明了一個監督學習問題。在經典的機器學習框架中,這個監督學習問題通常通過首先基于上下文中的實例-標簽對訓練一個分類器來解決,使用諸如梯度下降之類的方法,然后要求分類器預測查詢實例的標簽。值得注意的是,研究表明,transformers能夠在前向傳播中實現這個梯度下降訓練過程,而不需要調整任何參數,為上下文學習提供了一個可能的解釋。
超越監督學習,智能涉及到順序決策,其中強化學習已經成為一個成功的范式。transformers在推斷期間能否執行上下文RL,以及如何執行?為了解決這些問題,研究人員從馬爾可夫獎勵過程MRP中的一個簡單評估問題開始。在MRP中,代理程序在每個時間步中從一個狀態轉換到另一個狀態。用(S0,S1,S2,...)表示代理訪問的狀態序列。在每個狀態下,代理程序會接收到一個獎勵。用(r(S0),r(S1),r(S2),...)表示代理程序在路途中接收到的獎勵序列。評估問題是估計值函數v,該函數計算每個狀態未來代理程序將收到的期望總(折扣)獎勵。所需的輸入輸出的一個示例可以是:
引人注目的是,上述任務與監督學習根本不同,因為目標是預測值v(s),而不是即時獎勵r(s)。此外,查詢狀態s是任意的,不必是S3。時間差分學習TD是解決這類評估問題(2)的最常用的RL算法。而且眾所周知,TD不是梯度下降。
在這項工作中,研究人員做出了三個主要貢獻。首先,通過構造證明transformers具有足夠的表達能力來在前向傳播中實現TD,這一現象我們稱為上下文TD。換句話說,transformers能夠通過上下文TD在推斷時間內解決問題(2)。超越最直接的TD,transformers還可以實現許多其他策略評估算法,包括殘差梯度(Baird,1995)、帶有資格跟蹤的TD(Sutton,1988)和平均獎勵TD(Tsitsiklis和Roy,1999)。特別地,為了實現平均獎勵TD,transformers需要使用多頭注意力和過度參數化的提示,例如,
這里,“□”充當一個虛擬占位符,在推斷期間transformers將使用它作為“記憶”。第二,通過在多個隨機生成的評估問題上訓練transformers與TD,實證地證明了在推斷中出現了上下文TD。換句話說,學習的transformer參數與我們在證明中的構造非常相符。將這種訓練方案稱為多任務TD。第三,通過展示對于單層transformer,證明了實現上下文TD所需的transformer參數在多任務TD訓練算法的不變集合的子集中,來彌合理論和實證結果之間的差距。
論文:https://arxiv.org/pdf/2405.13861
本文轉載自公眾號AIGC最前線
原文鏈接:??https://mp.weixin.qq.com/s/voNZDTww7E5ec1hUwulztw??
