性能大漲20%!中科大「狀態序列頻域預測」方法:表征學習樣本效率max
強化學習算法(Reinforcement Learning, RL)的訓練過程往往需要大量與環境交互的樣本數據作為支撐。然而,現實世界中收集大量的交互樣本通常成本高昂或者難以保證樣本采集過程的安全性,例如無人機空戰訓練和自動駕駛訓練。
為了提升強化學習算法在訓練過程中的樣本效率,一些研究者們借助于表征學習(representation learning),設計了預測未來狀態信號的輔助任務,使得表征能從原始的環境狀態中編碼出與未來決策相關的特征。
基于這個思路,該工作設計了一種預測未來多步的狀態序列頻域分布的輔助任務,以捕獲更長遠的未來決策特征,進而提升算法的樣本效率。
該工作標題為State Sequences Prediction via Fourier Transform for Representation Learning,發表于NeurIPS 2023,并被接收為Spotlight。
作者列表:葉鳴軒,匡宇飛,王杰*,楊睿,周文罡,李厚強,吳楓
論文鏈接:https://openreview.net/forum?id=MvoMDD6emT
代碼鏈接:https://github.com/MIRALab-USTC/RL-SPF/
研究背景與動機
深度強化學習算法在機器人控制[1]、游戲智能[2]、組合優化[3]等領域取得了巨大的成功。但是,當前的強化學習算法仍存在「樣本效率低下」的問題,即機器人需要大量與環境交互的數據才能訓得性能優異的策略。
為了提升樣本效率,研究者們將目光投向于表征學習,希望訓得的表征能從環境的原始狀態中提取出充足且有價值的特征信息,從而提升機器人對狀態空間的探索效率。
基于表征學習的強化學習算法框架
在序列決策任務中,「長期的序列信號」相對于單步信號包含更多有利于長期決策的未來信息。啟發于這一觀點,一些研究者提出通過預測未來多步的狀態序列信號來輔助表征學習[4,5]。然而,直接預測狀態序列來輔助表征學習是非常困難的。
現有的兩類方法中,一類方法通過學習單步概率轉移模型來逐步地產生單個時刻的未來狀態,以間接預測多步的狀態序列[6,7]。但是,這類方法對所訓得的概率轉移模型的精度要求很高,因為每步的預測誤差會隨預測序列長度的增加而積累。
另一類方法通過直接預測未來多步的狀態序列來輔助表征學習[8],但這類方法需要存儲多步的真實狀態序列作為預測任務的標簽,所耗存儲量大。因此,如何有效從環境的狀態序列中提取有利于長期決策的未來信息,進而提升連續控制機器人訓練時的樣本效率是需要解決的問題。
為了解決上述問題,我們提出了一種基于狀態序列頻域預測的表征學習方法(State Sequences Prediction via Fourier Transform, SPF),其思想是利用「狀態序列的頻域分布」來顯式提取狀態序列數據中的趨勢性和規律性信息,從而輔助表征高效地提取到長期未來信息。
狀態序列中的結構性信息分析
我們從理論上證明了狀態序列存在「兩種結構性信息」,一是與策略性能相關的趨勢性信息,二是與狀態周期性相關的規律性信息。
馬爾科夫決策過程
在具體分析兩種結構性信息之前,我們先介紹產生狀態序列的馬爾科夫決策過程(Markov Decision Processes,MDP)的相關定義。
我們考慮連續控制問題中的經典馬爾可夫決策過程,該過程可用五元組 表示。其中, 為相應的狀態、動作空間, 為獎勵函數, 為環境的狀態轉移函數, 為狀態的初始分布, 為折扣因子。此外,我們用 表示策略在狀態 下的動作分布。
我們將 時刻下智能體所處的狀態記為 ,所選擇的動作記為 .智能體做出動作后,環境轉移到下一時刻狀態 并反饋給智能體獎勵 。我們將智能體與環境交互過程中所得到狀態、動作對應的軌跡記為 ,軌跡服從分布 。
強化學習算法的目標是最大化未來預期的累積回報,我們用 表示當前策略 和 環境模型 下的平均累積回報,并簡寫為 ,定義如下:
顯示了當前策略 的性能表現。
趨勢性信息
下面我們介紹狀態序列的「第一種結構性特征」,其涉及狀態序列和對應獎勵序列之間的依賴關系,能顯示出當前策略的性能趨勢。
在強化學習任務中,未來的狀態序列很大程度上決定了智能體未來采取的動作序列,并進一步決定了相應的獎勵序列。因此,未來的狀態序列不僅包含環境固有的概率轉移函數的信息,也能輔助表征捕獲反映當前策略的走向趨勢。
啟發于上述結構,我們證明了以下定理,進一步論證了這一結構性依賴關系的存在:
定理一:若獎勵函數只與狀態有關,那么對于任意兩個策略 和 ,他們的性能差異可以被這兩個策略所產生的狀態序列分布差異所控制:
上述公式中, 表示在指定策略和轉移概率函數條件下狀態序列的概率分布, 表示 范數。
上述定理表明,兩個策略的性能差異越大,其對應的兩個狀態序列的分布差異也越大。這意味著好策略和壞策略會產生出兩個差異較大的狀態序列,這進一步說明狀態序列所包含的長期結構性信息能潛在影響搜索性能優異的策略的效率。
另一方面,在一定條件下,狀態序列的頻域分布差異也能為對應的策略性能差異提供上界,具體如以下定理所示:
定理二:若狀態空間有限維且獎勵函數是與狀態有關的n次多項式,那么對于任意兩個策略 和 ,他們的性能差異可以被這兩個策略所產生的狀態序列的頻域分布差異所控制:
上述公式中, 表示由策略 所產生的狀態序列的 次方序列的傅里葉函數, 表示傅里葉函數的第 個分量。
這一定理表明狀態序列的頻域分布仍包含與當前策略性能相關的特征。
規律性信息
下面我們介紹狀態序列中存在的「第二種結構性特征」,其涉及到狀態信號之間的時間依賴性,即一段較長時期內狀態序列所表現出的規律性模式。
在許多的真實場景任務中,智能體也會表現出周期性行為,因為其環境的狀態轉移函數本身就是具有周期性的。以工業裝配機器人為例,該機器人的訓練目標是將零件組裝在一起以創造最終產品,當策略訓練達到穩定時,它就會執行一個周期性的動作序列,使其能夠有效地將零件組裝在一起。
啟發于上面的例子,我們提供了一些理論分析,證明了有限狀態空間中,當轉移概率矩陣滿足某些假設,對應的狀態序列在智能體達到穩定策略時可能表現出「漸近周期性」,具體定理如下:
定理三:對于狀態轉移矩陣為 的有限維狀態空間 ,假設 有 個循環類,對應的狀態轉移子矩陣為 。設這 個矩陣模為1的特征值個數為 ,則對于任意狀態的初始分布 ,狀態分布 呈現出周期為 的漸進周期性。
在MuJoCo任務中,策略訓練達到穩定時,智能體也會表現出周期性的運動。下圖中給出了MuJoCo任務中HalfCheetah智能體在一段時間內的狀態序列示例,可以觀察到明顯的周期性。(更多MuJoCo任務中帶周期性的狀態序列示例可參考本論文附錄第E節)
MuJoCo任務中HalfCheetah智能體在一段時間內狀態所表現出的周期性
時間序列在時域中呈現的信息相對分散,但在頻域中,序列中的規律性信息以更加集中的形式呈現。通過分析頻域中的頻率分量,我們能顯式地捕獲到狀態序列中存在的周期性特征。
方法介紹
上一部分中,我們從理論上證明狀態序列的頻域分布能反映策略性能的好壞,并且通過在頻域上分析頻率分量我們能顯式捕獲到狀態序列中的周期性特征。
啟發于上述分析,我們設計了「預測無窮步未來狀態序列傅里葉變換」的輔助任務來鼓勵表征提取狀態序列中的結構性信息。
SPF方法損失函數
下面介紹我們關于該輔助任務的建模。給定當前狀態 和動作 ,我們定義未來的狀態序列期望如下:
我們的輔助任務訓練表征去預測上述狀態序列期望的離散時間傅里葉變換(discrete-time Fourier transform, DTFT),即
上述傅里葉變換公式可改寫為如下的遞歸形式:
其中,
其中, 為狀態空間的維度, 為所預測的狀態序列傅里葉函數的離散化點的個數。
啟發于Q-learning中優化Q值網絡的TD-error損失函數[9],我們設計了如下的損失函數:
其中, 和 分別為損失函數要優化的表征編碼器(encoder)和傅里葉函數預測器(predictor)的神經網絡參數, 為存儲樣本數據的經驗池。
進一步地,我們可以證明上述的遞歸公式可以表示為一個壓縮映射:
定理四:令 表示函數族 ,并定義 上的范數為:
其中 表示矩陣 的第 行向量。我們定義映射 為
則可以證明 為一個壓縮映射。
根據壓縮映射原理,我們可以迭代地使用算子 ,使得 逼近真實狀態序列的頻域分布,且在表格型情況(tabular setting)下有收斂性保證。
此外,我們所設計的損失函數只依賴于當前時刻與下一時刻的狀態,所以無需存儲未來多步的狀態數據作為預測標簽,具有「實施簡單且存儲量低」的優點。
SPF方法算法框架
下面我們介紹本論文方法(SPF)的算法框架。
基于狀態序列頻域預測的表征學習方法(SPF)的算法框架圖
我們將當前時刻和下一時刻的狀態-動作數據分別輸入到在線(online)和目標(target)表征編碼器(encoder)中,得到狀態-動作表征數據,然后將該表征數據輸入到傅里葉函數預測器(predictor)得到當前時刻和下一時刻下的兩組狀態序列傅里葉函數預測值。通過代入這兩組傅里葉函數預測值,我們能計算出損失函數值。
我們通過最小化損失函數來優化更新表征編碼器 和傅里葉函數預測器 ,使預測器的輸出能逼近真實狀態序列的傅里葉變換,從而鼓勵表征編碼器提取出包含未來長期狀態序列的結構性信息的特征。
我們將原始狀態和動作輸入到表征編碼器中,將得到的特征作為強化學習算法中actor網絡和critic網絡的輸入,并用經典強化學習算法優化actor網絡和critic網絡。
實驗結果
(注:本節僅選取部分實驗結果,更詳細的結果請參考論文原文第6節及附錄。)
算法性能比較
我們將 SPF 方法在 MuJoCo 仿真機器人控制環境上測試,對如下 6 種方法進行對比:
- SAC:基于Q值學習的soft actor-critic算法[10],一種傳統的RL算法;
- PPO:基于策略優化的proximal policy optimization算法[11],一種傳統RL算法;
- SAC-OFE:利用預測單步未來狀態的輔助任務進行表征學習,以優化SAC算法;
- PPO-OFE:利用預測單步未來狀態的輔助任務進行表征學習,以優化PPO算法;
- SAC-SPF:利用預測無窮步狀態序列的頻域函數的輔助任務進行表征學習(我們的方法),以優化SAC算法;
- PPO-SPF:利用預測無窮步狀態序列的頻域函數的輔助任務進行表征學習(我們的方法),以優化PPO算法;
基于6種MuJoCo任務的對比實驗結果
上圖顯示了在 6 種 MuJoCo 任務中,我們所提出的SPF方法(紅線及橙線)與其他對比方法的性能曲線。結果顯示,我們所提出的方法相比于其他方法能獲得19.5%的性能提升。
消融實驗
我們對 SPF 方法的各個模塊進行了消融實驗,將本方法與不使用投影器模塊(noproj)、不使用目標網絡模塊(notarg)、改變預測損失(nofreqloss)、改變特征編碼器網絡結構(mlp,mlp_cat)時的性能表現做比較。
SPF方法應用于SAC算法的消融實驗結果圖,測試于HalfCheetah任務
可視化實驗
我們使用 SPF 方法所訓練好的預測器輸出狀態序列的傅里葉函數,并通過逆傅里葉變換恢復出的200步狀態序列,與真實的200步狀態序列進行對比。
基于傅里葉函數預測值恢復出的狀態序列示意圖,測試于Walker2d任務。其中,藍線為真實的狀態序列示意圖,5條紅線為恢復出的狀態序列示意圖,越下方的、顏色越淺的紅線表示利用越久遠的歷史狀態所恢復出的狀態序列。
結果顯示,即使用更久遠的狀態作為輸入,恢復出的狀態序列也和真實的狀態序列非常相似,這說明 SPF 方法所學習出的表征能有效編碼出狀態序列中包含的結構性信息。