AI「領悟」有理論解釋了!谷歌:兩種腦回路內部競爭,訓練久了突然不再死記硬背
谷歌PAIR團隊不久前撰文介紹了AI的“領悟” (Grokking)現象——
訓練久了突然不再死記硬背,而是學會舉一反三,有了泛化能力。
不出一個月,另一只團隊(主要成員來自DeepMind)表示,已經給出一個通用理論解釋——
領悟又稱延遲泛化,與AI內部兩種“腦回路”的競爭有關。
對此,有學者評價“我們需要更多這種對深度學習物理規律的研究,而不是去優化煉金術。”
AI的兩種腦回路
在先前的研究中,發現在“領悟”現象的作用下,就算只有5-24個神經元的模型也能擁有泛化能力。
新研究沿用了這種構建最小示例,以及大量做可視化的方法。
基于OpenAI在2020年一項對神經網絡內部機制之間相互作用的研究,團隊假設并驗證了模型內部有兩種算法回路(Circuits)。
- 記憶回路Cmem,訓練時表現很好,但測試時表現不佳。
- 泛化回路Cgen,訓練和測試階段表現都好。
通過改變數據集的大小和權重衰減的強度做實驗來觀察。
當訓練數據集增大時,Cmem回路的參數范數也更大,也就是在靠記憶的方式去存儲訓練集需要的信息量。
但Cgen的參數范數不隨訓練集大小變化,也就是獲得了類似“舉一反三”的泛化能力
那么,在什么條件下模型會發生整體的“領悟”現象呢?
來自兩種回路的之間競爭。
在訓練初期,直接死記硬背的速度更快,Cmem占據上風。
但隨著數據的增加,在梯度下降的作用下效率更高的Cgen會被加強。
也就是說,存在兩種不同的回路、他們之間有效率差和學習速度差是導致領悟發生的三大要素。
重新思考泛化
在更進一步的實驗中,團隊還根據這個理論成功演示了在一定條件下,已經“領悟”的模型也可以退化,出現“逆領悟”。
在新的小數據集上繼續訓練已領悟的模型時,測試精度突然變差,也就是在泛化之后的過擬合。
也可以精心調整出一個“半領悟”狀態。
當數據集的大小剛好在一個臨界值,讓Cmem和Cgen的效率相當,只對部分測試精度出現延遲泛化。
團隊認為,這種基于回路效率的分析為理解神經網絡的泛化提供了一種新的視角。
同時也提出了一些后續研究方向。
如為什么領悟所需的時間隨數據集大小的減小呈超指數級增長?為什么Cgen回路的學習速度慢?為什么在沒有權重衰減的情況下也會發生grokking?為什么在典型的機器學習訓練中沒有領悟現象?……
評論區有學者認為,研究這些基礎問題并不需要成千上萬塊H100。
GPU貧民也有機會為整個領域做出貢獻。
論文地址:https://arxiv.org/abs/2309.02390