聊聊 KAN、KAN 卷積結合注意力機制!
第一類 基礎線性層替換
KAN 層替換線性層 Linear:
更新關于LSTM、TCN、Transformer模型中用 KAN 層替換線性層的故障分類模型。
KAN 的準確率要優于 MLP,我們可以進一步嘗試在常規模型的最后一層線性層都替換為 KAN 層來進行對比;KAN 卷積比常規卷積準確率有略微的提升!
第二類 并行融合模型
KAN卷積、GRU并行:
故障信號同時送入并行模型,分支一經過 KAN卷積進行學習,分支二利用 GRU 提取故障時域特征,然后并行特征進行堆疊融合,來增強故障信號特征提取能力。
2.1 定義 KANConv-GRU 分類網絡模型
2.2 設置參數,訓練模型
50個epoch,訓練集、驗證集準確率97%,用改進 KANConv-GRU 并行網絡分類效果顯著,模型能夠充分提取軸承故障信號中的故障特征,收斂速度快,性能優越,精度高,效果明顯!
2.3 模型評估
準確率、精確率、召回率、F1 Score
故障十分類混淆矩陣:
第三類 結合注意力機制
3.1 KAN 結合自注意力機制:
我們創造性的提出在利用 KAN 層提取的特征作為自注意力機制的輸入,來進一步增加非線性能力,具體步驟如下:
1.輸入嵌入:
首先使用 unsqueeze 將輸入從 ([batch_size, input_dim]) 擴展為 ([batch_size, 1, input_dim]),以便兼容后續的操作。
使用 input_proj 線性層將輸入從 ([batch_size, 1, input_dim]) 映射到 ([batch_size, 1, embed_dim])。
2.查詢-鍵-值投影:
- 使用 qkv_proj 線性層將輸入映射到查詢、鍵和值的嵌入空間,結果形狀為 ([batch_size, 1, embed_dim * 3])。
3. 重塑和轉置:
- 將 qkv 重塑為 ([batch_size, 1, 3, num_heads, head_dim])。
- 然后將維度重新排列為 ([3, batch_size, num_heads, 1, head_dim])。
4.計算注意力權重和輸出:
- 通過縮放的點積計算注意力權重,并對其進行 softmax 歸一化。
- 使用注意力權重與值進行加權求和,得到注意力輸出。
5.輸出重塑和映射:
- 將注意力輸出重新排列并重塑為 ([batch_size, 1, embed_dim])。
- 使用 o_proj 線性層將自注意力機制的輸出從 ([batch_size, 1, embed_dim]) 映射回 ([batch_size, 1, input_dim])。
- 使用 squeeze 移除序列長度的維度,得到最終輸出 ([batch_size, input_dim])。
?
?
通過這種方式,輸入和輸出的維度保持一致。自注意力機制通過計算每個輸入元素與其他所有輸入元素之間的相關性(注意力分數),并利用這些相關性來加權求和,更新每個輸入元素的表示,從而捕捉到輸入序列中元素之間的依賴關系。進一步加強了 KAN 輸出信息對復雜特征的建模能力。
3.2 KAN 卷積結合通道注意力機制SENet:
KAN 卷積與卷積非常相似,但不是在內核和圖像中相應像素之間應用點積,而是對每個元素應用可學習的非線性激活函數,然后將它們相加。我們在KAN卷積的基礎上融合通道注意力機制,進一步加強了對特征的提取能力!
從對比實驗可以看出, 在軸承故障診斷任務中:
KAN卷積融合注意力機制后,效果提升明顯,后續還可以進一步嘗試與其他類型的注意力機制做融合!
本文轉載自 ??建模先鋒??,作者: 小蝸愛建模
