Grad-CAM的詳細介紹和Pytorch代碼實現
Grad-CAM (Gradient-weighted Class Activation Mapping) 是一種可視化深度神經網絡中哪些部分對于預測結果貢獻最大的技術。它能夠定位到特定的圖像區域,從而使得神經網絡的決策過程更加可解釋和可視化。
Grad-CAM 的基本思想是,在神經網絡中,最后一個卷積層的輸出特征圖對于分類結果的影響最大,因此我們可以通過對最后一個卷積層的梯度進行全局平均池化來計算每個通道的權重。這些權重可以用來加權特征圖,生成一個 Class Activation Map (CAM),其中每個像素都代表了該像素區域對于分類結果的重要性。
相比于傳統的 CAM 方法,Grad-CAM 能夠處理任意種類的神經網絡,因為它不需要修改網絡結構或使用特定的層結構。此外,Grad-CAM 還可以用于對特征的可視化,以及對網絡中的一些特定層或單元進行分析。
在Pytorch中,我們可以使用鉤子 (hook) 技術,在網絡中注冊前向鉤子和反向鉤子。前向鉤子用于記錄目標層的輸出特征圖,反向鉤子用于記錄目標層的梯度。在本篇文章中,我們將詳細介紹如何在Pytorch中實現Grad-CAM。
加載并查看預訓練的模型
為了演示Grad-CAM的實現,我將使用來自Kaggle的胸部x射線數據集和我制作的一個預訓練分類器,該分類器能夠將x射線分類為是否患有肺炎。
首先我們看看這個模型的架構。就像前面提到的,我們需要識別最后一個卷積層,特別是它的激活函數。這一層表示模型學習到的最復雜的特征,它最有能力幫助我們理解模型的行為,下面是我們這個演示模型的代碼:
模型3通道接收256x256的圖片。它期望輸入為[batch size, 3,256,256]。每個ResNet塊以一個ReLU激活函數結束。對于我們的目標,我們需要選擇最后一個ResNet塊。
在Pytorch中,我們可以很容易地使用模型的屬性進行選擇。
Pytorch的鉤子函數
Pytorch有許多鉤子函數,這些函數可以處理在向前或后向傳播期間流經模型的信息。我們可以使用它來檢查中間梯度值,更改特定層的輸出。
在這里,我們這里將關注兩個方法:
該方法在模塊上注冊了一個后向傳播的鉤子,當調用backward()方法時,鉤子函數將會運行。后向鉤子函數接收模塊本身的輸入、相對于層的輸入的梯度和相對于層的輸出的梯度
它返回一個torch.utils.hooks.RemovableHandle,可以使用這個返回值來刪除鉤子。我們在后面會討論這個問題。
這與前一個非常相似,它在前向傳播中后運行,這個函數的參數略有不同。它可以讓你訪問層的輸出:
它的返回也是torch.utils.hooks.RemovableHandle
向模型添加鉤子函數
為了計算Grad-CAM,我們需要定義后向和前向鉤子函數。這里的目標是關于最后一個卷積層的輸出的梯度,需要它的激活,即層的激活函數的輸出。鉤子函數會在推理和向后傳播期間為我們提取這些值。
在定義了鉤子函數和存儲激活和梯度的變量之后,就可以在感興趣的層中注冊鉤子,注冊的代碼如下:
檢索需要的梯度和激活
現在已經為模型設置了鉤子函數,讓我們加載一個圖像,計算gradcam。
為了進行推理,我們還需要對其進行預處理:
現在就可以進行前向傳播了:
鉤子函數的返回如下:
得到了梯度和激活變量后就可以生成熱圖:
計算Grad-CAM
為了計算Grad-CAM,我們將原始論文公式進行一些簡單的修改:
結果如下:
得到的激活包含1024個特征映射,這些特征映射捕獲輸入圖像的不同方面,每個方面的空間分辨率為8x8。通過鉤子獲得的梯度表示每個特征映射對最終預測的重要性。通過計算梯度和激活的元素積可以獲得突出顯示圖像最相關部分的特征映射的加權和。通過計算加權特征圖的全局平均值,可以得到一個單一的熱圖,該熱圖表明圖像中對模型預測最重要的區域。這就是Grad-CAM,它提供了模型決策過程的可視化解釋,可以幫助我們解釋和調試模型的行為。
但是這個圖能代表什么呢?我們將他與圖片進行整合就能更加清晰的可視化了。
結合原始圖像和熱圖
下面的代碼將原始圖像和我們生成的熱圖進行整合顯示:
這樣看是不是就理解多了。由于它是一個正常的x射線結果,所以并沒有什么需要特殊說明的。
再看這個例子,這個結果中被標注的是肺炎。Grad-CAM能準確顯示出醫生為確定是否患有肺炎而必須檢查的胸部x光片區域。也就是說我們的模型的確學到了一些東西(紅色區域再肺部附近)
刪除鉤子
要從模型中刪除鉤子,只需要在返回句柄中調用remove()方法。
總結
這篇文章可以幫助你理清Grad-CAM 是如何工作的,以及如何用Pytorch實現它。因為Pytorch包含了強大的鉤子函數,所以我們可以在任何模型中使用本文的代碼。