下一個AI前沿與革命:KAN 上
1.KAN
這種新型的網絡架構的核心思想基于由柯爾莫哥洛夫-阿諾德表示定理,它被寄予期望能夠替代多層感知器。MLP 在節點(“神經單元”)上具有固定的激活函數,而 KAN 在邊上(“權重”)具有可學習的激活函數。KAN根本沒有線性權重—每個權重參數都被參數化為一元的spline function。
大白話的意思就是:KAN中的每個激活函數不是在每個節點,而是在每條邊上。由一個一元函數(univariate function)組成,這些函數本身也是參數。意味著函數即參數,每個權重參數不再是一個單一的數值,而是一個函數。
spline,a continuous curve constructed so as to pass through a given set of points and have a certain number of continuous derivatives.大致的意思就是spline是一條連續的曲線,這條曲線能夠穿過給點的數據點,且擁有特定數目的連續導
具體什么意思,不要著急慢慢來。先來看看KAN相對于MLP的優點:
- 這種簡單的變化使得KAN在準確性和可解釋性方面優于 MLP。
- ?從理論上和經驗上來說,KAN比MLP擁有更快的神經尺度法則。
- KAN 可以直觀地可視化,并且可以輕松地與人類用戶交互。
?
上圖為兩者之間的對比,最大的區別在于KAN學習的對象是邊的激活函數,而每個節點僅僅做數值累加,KAN的多層累加有點函數套娃的意思。傳統的激活函數是什么,可以回到??初心??去看看。
2.柯爾莫哥洛夫-阿諾德表示定理
這個定理是KAN的基石,用大白話文講就是任何一個函數都可以轉化為由單個變量的函數再套一層單變量的函數。
為了讓大家更好的理解,舉個一元二次方程的例子。相信大部分的同學都能寫出來根的公式。而一元三次方程的根其實也是可以表示為a,b,c,d的函數。
那么一元四次,五次,六次呢,是不是更加復雜,關鍵是還能寫得出來么。因此這個定理的貢獻在于將高維函數簡化成多項式數量的一維函數的組合。
為了讓大家更好的理解,細究下這個定理的歷史。故事來至希爾伯特的23個問題:大致的背景是德國數學家大衛·希爾伯特于1900年在巴黎舉行的第二屆國際數學家大會上作了題為《數學問題》的演講,所提出23道最重要的數學問題。其中的第十三問題,希爾伯特希望數學界能夠證明:x7+ax3+bx2+cx+1=0這個方程式的七個解,若表成系數為a,b,c的函數,則此函數無法簡化成兩個變數的函數。
后續柯爾莫哥洛夫證明每個有多個變元的函數可用有限個三變元函數構作。阿諾爾德按這個結果繼續研究證明兩個變元已足夠。之后阿諾爾德和日本數學家志村五郎發表了論文(Superposition of algebraic functions (1976), in Mathematical Developments Arising From Hilbert's Problems)。
這些結果后來被進一步發展,推導出人工神經網絡中的通用近似定理,指人工神經網絡能近似任意連續函數。
3.KAN架構
先回到KA定理的公式:
聰明的讀者一定會發現,要是將這個做成神經網絡,是不是只有兩層非線性和一個隱藏層(2n+1),因為函數只套了一次。對的!加上一維函數可能是非光滑的,甚至是分形的(fractal),在實踐無法學習。
這個函數其實看起來復雜,理解卻是不難。第一層輸入數目為n,舉個例子,X1自身對應著2n+1個內圈函數Θq,1 (q=0...2n+1)。所以一共有n*(n+1)個內圈函數,將Θ1,1 , Θ1,2 ,Θ1,3 ,Θ1,4 , ... ,Θ1,n進行累加輸出,一共輸出2n+1個數值。第二層將2n+1個累加數值輸入外圈函數Φ,得到1個輸出。所以傳統的KA是兩層的,n->2n+1->1。
然而這次MIT做了技術突破,它們擴展了KA,提出了KAN架構。KAN架構的好處在于保留函數即參數的內核之外,將兩層約束擴展到任意可以堆疊的網絡結構。
KAN的網絡結構可以由[n0, n1, · · · , nL]這個整形數組來表達,每個數值代表著每一層的節點個數(節點執行累加的動作)。下圖為中間層,任何一層輸入,假定這個數組為[5,4,1],那么最早一層就是4*5的函數矩陣,在往下就是1*4的矩陣,最終輸出為1個數值。想想為什么?不熟悉矩陣的同學可以溫習下??這篇文章??。
最終KAN網絡的運作模式如下:從輸入不斷經過函數矩陣的變化達到最終的數值。2.7比較形象,2.8對于數學比較友好。
4.激活函數
?
激活函數的長相如2.10所示,它由殘差函數和Spline函數組合而成。w為權重,雖然它有點多余,畢竟可以被這兩個函數吸收掉,但是可以來控制整個激活數值的縮放。2.11展示了殘差函數。而Spline函數則是由B-splines構成。數學小白可以跳過B-splines函數,它其實就是分段連續的多項式曲線。
它有兩個重要參數:節點和次數。數值域被細分成節點劃分而成的多個區間。如何理解上面的公式在這里不重要,最重要的是它是一種構造曲線的方式。如此通過學習,校正激活函數以便于獲得期望的輸入和輸出。
B-splines是一個分段連續的多項式曲線,它的參數域通過節點(knots)來表示,每兩個節點之間是參數域的一段,比如一個B樣條的參數域可以表示為:[??0,??1,??2,??3,??4,??5,??6,??7,??8] ,一共9個節點;參數域分為8段:[??0,??1],[??1,??2],[??2,??3],[??3,??4],[??4,??5],...,[??7,??8] 。下面4幅圖較為直觀,可以通過不同的基函數構造出下面的曲線:
5.性能
KAN的訓練過程不在本文描述,它需要一定的技術背景,后續會另開專題詳細的解釋。為了可視化,研究人員設計一種交互式的監督學習,通過初訓練,剪枝,然后設定一些比較常見的函數,最終再次訓練參數(affine parameter)進而得到結果。事實上,它在已知函數表達式和未知函數表達式上的模擬都超過MLP。
下圖為五個函數,分別采用的KAN網絡模擬,最多不會超過4層。最后一張圖列出了它們和mlp對比,橫軸為參數。
