我們一起聊聊基于 KAN、KAN卷積的軸承故障診斷模型
前言
本文基于凱斯西儲大學(CWRU)軸承數據,先經過數據預處理進行數據集的制作和加載,最后通過Pytorch實現優化的KAN模型和KAN卷積模型對故障數據的分類。
1、KAN 網絡介紹
1.1 KAN 網絡三大特征
- 數學上有據可依
- 準確性高
- 可解釋性強
1.2 傳統 MLP 的本質
多層感知機(MLPs),也稱為全連接前饋神經網絡,是深度學習模型的基礎構建塊。MLPs 的重要性不言而喻,因為它們是機器學習中用于逼近非線性函數的默認模型,其表達能力由普適逼近定理保證。
(1)容易產生梯度消失和梯度爆炸:
- 梯度消失:梯度趨近于零,網絡權重無法更新或更新的很微小,網絡訓練再久也不會有效果;
- 梯度爆炸:梯度呈指數級增長,變的非常大,然后導致網絡權重的大幅更新,使網絡變得不穩定。
(2)參數效率低:
MLP 通常使用全連接層,意味著每層的每個神經元都與前一層的所有神經元相連接,導致參數數量迅速增加,尤其是對輸入維度很高的數據;這不僅增加了計算負擔,也增加了模型過擬合的風險。
(3)可解釋性差:
盡管 MLPs 的使用普遍,但它們有著顯著的缺點。例如,在Transformer中,MLPs 幾乎消耗所有非嵌入參數,并且通常在沒有后續分析工具的情況下(相對于注意力層)不太可解釋。其可解釋性比較差,和一個黑盒模型一樣,無法探究是怎么進行學習的。
1.3 MLP 與 KAN 對比
(1)Kolmogorov-Arnold 定理:
任何一個多變量連續函數都可以表示為一些單變量函數的組合!(在數學的視角,任何問題的核心都是在擬合函數)
(2)激活函數可學習的:
神經網絡中每一層的輸入輸出都是一個線性求和的過程,所以如果沒有激活函數,那么無論你構造的神經網絡多么復雜,有多少層,最后的輸出都是輸入的線性組合,純粹的線性組合并不能夠解決更為復雜的問題。而引入激活函數之后,我們會發現常見的激活函數都是非線性的,使得神經網絡可以逼近其他的任何非線性函數。與MLP不同激活函數固定 ,而 KAN 激活函數可學習的, 是可變的!
- MLP: 激活函數固定, 輸入先相加再激活
- KAN: 激活函數可學習的,輸入先激活再相加
(3)樣條函數:
KAN 中的每層非線性函數 Ф 都采用同樣的函數結構,只是用不同的參數來控制其形狀,文章選擇了數值分析中的樣條函數 spline ,樣條理論是函數逼近的有力工具。
樣條函數是由多個多項式片段組成的函數,每個片段在相鄰節點之間定義。這些片段在節點處連接,以確保整體函數的光滑性。
b樣條曲線有一個優勢就是有明顯的幾何意義。通過砍角算法(嵌套的線性插值)可以方便的進行曲線的細分、導矢計算、曲線分割、逼近(消去節點),不僅可以方便的進行各種操作,而且精度比采用冪基函數的多項式樣條高。
(4)MLP 與 KAN 對比:
MPL 是固定的非線性激活 + 線性參數學習,KAN 則是直接對參數化的非線性激活函數的學習。KAN 實現了使用更少的節點,更小的網絡,來實現同樣的效果,甚至更優的效果!
1.4 KAN 執行過程
1.5 可解釋性
運行代碼文件中的 hellokan.ipynb 實現上述可視化過程
2 KAN 卷積(CKAN)
2.1 CKAN
最近,有研究者將 KAN 創新架構的理念擴展到卷積神經網絡,將卷積的經典線性變換更改為每個像素中可學習的非線性激活函數,提出并開源 KAN 卷積(CKAN)
KAN 卷積與卷積非常相似,但不是在內核和圖像中相應像素之間應用點積,而是對每個元素應用可學習的非線性激活函數,然后將它們相加。KAN 卷積的內核相當于 4 個輸入和 1 個輸出神經元的 KAN 線性層。
2.2 CKAN 中的參數
假設有一個 KxK 內核,對于該矩陣的每個元素,都有一個 ?,其參數計數為:gridsize + 1,? 定義為:
這為激活函數 b 提供了更多的可表達性,線性層的參數計數為 gridsize + 2。因此,KAN 卷積總共有 K^2(gridsize + 2) 個參數,而普通卷積只有 K^2。
3.3 CKAN 在軸承故障診斷中的應用
通過前面的對比實驗可以看出,基于 KAN 的卷積網絡比傳統卷積網絡在軸承故障分類任務上效果會好一些,但是訓練時間較長。后續可以考慮融合其他模塊,做進一步優化;同時基礎的 KAN 層完全可以替代分類任務中的全連接層,效果顯著,可以在其他數據集上做進一步的對比實驗。總的來說,KAN 卷積的實現是一個很有前景的想法,在軸承故障診斷任務上也存在一定的應用前景,值得我們去探索!
3 軸承故障數據的預處理
3.1 導入數據
參考之前的文章,進行故障10分類的預處理,凱斯西儲大學軸承數據10分類數據集:
train_set、val_set、test_set 均為按照7:2:1劃分訓練集、驗證集、測試集,最后保存數據
上圖是數據的讀取形式以及預處理思路
3.2 數據預處理,制作數據集
4 基于 Pytorch的 KANConv 的軸承故障診斷
4.1 定義 KANConv 分類網絡模型,設置參數,訓練模型
100個epoch,訓練集、驗證集準確率98%,用改進 KAN 卷積 網絡分類效果顯著,模型能夠充分提取軸承故障信號中的故障特征,收斂速度快,性能優越,精度高,效果明顯!(代價是運行時間比傳統CNN網絡要慢)
4.2 模型評估
準確率、精確率、召回率、F1 Score
故障十分類混淆矩陣:
本文轉載自 ??建模先鋒??,作者: 小蝸愛建模
