空間變換網(wǎng)絡(luò)簡單介紹
作為谷歌Summer of Code項目的一部分,我要實現(xiàn)的第一個模型是空間變壓器網(wǎng)絡(luò)。空間變壓器網(wǎng)絡(luò)(STN)是一個可學(xué)習的模塊,可以放置在卷積神經(jīng)網(wǎng)絡(luò)(CNN)中,有效地增加空間不變性??臻g不變性是指模型對圖像的空間變換如旋轉(zhuǎn)、平移和縮放不變性。不變性是指即使輸入被變換或輕微修改,模型也能識別和識別特征的能力。空間變壓器可以放置到CNN中,以完成各種任務(wù)。圖像分類就是一個例子。假設(shè)任務(wù)是對手寫數(shù)字進行分類,每個樣本中數(shù)字的位置、大小和方向變化顯著。一個空間轉(zhuǎn)換器將提取、變換和縮放樣本中感興趣的區(qū)域?,F(xiàn)在CNN可以完成分類的任務(wù)。

空間變壓器網(wǎng)絡(luò)由3個主要組成部分組成:
(i) 定位網(wǎng)絡(luò):該網(wǎng)絡(luò)以一個batch的圖像的四維張量表示(寬度x高度x通道x Batch_Size)作為輸入。它是一個簡單的神經(jīng)網(wǎng)絡(luò),有幾個卷積層和幾個dense層。將變換參數(shù)預(yù)測為輸出。這些參數(shù)決定了輸入必須旋轉(zhuǎn)的角度、要完成的平移量以及聚焦于輸入特征圖中感興趣的區(qū)域所需的比例因子。
(ii) 采樣網(wǎng)格生成器:對batch中每幅圖像使用定位網(wǎng)絡(luò)預(yù)測的變換參數(shù),其形式為大小為2×3的仿射變換矩陣。仿射變換是一種保留點、直線和平面的變換。經(jīng)過仿射變換后,平行線保持平行。旋轉(zhuǎn)、縮放和平移都是仿射變換。

這里,T是這個仿射變換,A是表示仿射變換的矩陣。θ11, θ12, θ21, θ22被用來確定圖像旋轉(zhuǎn)的角度。θ13, θ23分別確定了圖像沿寬度和高度的平移量。因此,我們得到了一個轉(zhuǎn)換索引的采樣網(wǎng)格。
(iii) 變換后索引上的雙線性插值:現(xiàn)在圖像的索引和坐標軸已經(jīng)進行了仿射變換。它的像素移動了。例如,一個點(1,1)在軸逆時針旋轉(zhuǎn)45度后變成(√2,0),因此要找到變換點處的像素值,我們需要使用四個最接近的像素值進行雙線性插值。

為了找到點(x, y)上的像素值,我們?nèi)?個最近的點,如上圖所示。其中,floor(x)表示最大整數(shù)函數(shù),ceil(x)表示ceiling函數(shù)。線性插值必須在x和y兩個方向上完成。因此,這個函數(shù)返回完全轉(zhuǎn)換后的圖像,并在轉(zhuǎn)換索引處使用適當?shù)南袼刂怠?/p>
純Julia實現(xiàn)空間變壓器網(wǎng)絡(luò)的代碼可以在這里找到:https://github.com/thebhatman/Spatial-Transformer-Network/blob/master/src/stn.jl。我在一些圖像上測試了我的空間轉(zhuǎn)換器模塊的功能。下面是轉(zhuǎn)換函數(shù)輸出的一些示例圖像。左邊的圖像是轉(zhuǎn)換器模塊的輸入,右邊的圖像是輸出。
- 放大感興趣的區(qū)域

- 對人臉進行放大并旋轉(zhuǎn)45度。

- 對圖像沿著寬度平移,移到中心。

從上面的例子可以清楚地看出,空間轉(zhuǎn)換器模塊能夠執(zhí)行任何類型的仿射變換。在實現(xiàn)過程中,我花了很多時間來理解數(shù)組的reshape、permutedims和concatenation是如何工作的,因為當我使用這些函數(shù)時,很難調(diào)試像素和索引是如何移動的。在STN實現(xiàn)過程中,調(diào)試插值和圖像索引是最耗費時間和最令人沮喪的部分。
現(xiàn)在,我計劃使用一個CNN來訓(xùn)練這個空間轉(zhuǎn)換器模塊,以便對一個雜亂和扭曲的MNIST數(shù)據(jù)集進行手寫數(shù)字分類??臻g變壓器將能夠增加CNN的空間不變性,因此期望即使在數(shù)字被平移、旋轉(zhuǎn)或縮放時也能給出良好的分類結(jié)果。