終于把 Unet 算法搞懂了!!
今天給大家分享一個超強的算法模型,Unet
UNet 是一種經典的卷積神經網絡(CNN)架構,最初由 Olaf Ronneberger 等人在 2015 年提出,專為生物醫學圖像分割設計。
它的獨特之處在于其編碼器-解碼器對稱結構,能夠有效地在多尺度上提取特征并生成精確的像素級分割結果。
UNet 算法在圖像分割任務中表現優異,尤其是在需要精細邊界的場景中廣泛應用,如醫學影像分割、衛星圖像分割等。
圖片
UNet 架構
UNet 模型由兩部分組成:編碼器和解碼器,中間通過跳躍連接(Skip Connections)相連。
UNet 的設計理念是將輸入圖像經過一系列卷積和下采樣操作逐漸提取高層次特征(編碼路徑),然后通過上采樣逐步恢復原始的分辨率(解碼路徑),并將編碼路徑中對應的特征與解碼路徑進行跳躍連接(skip connection)。這種跳躍連接能夠幫助網絡結合低層次細節信息和高層次語義信息,實現精確的像素級分割。
編碼器
類似傳統的卷積神經網絡,編碼器的主要任務是逐漸壓縮輸入圖像的空間分辨率,提取更高層次的特征。
這個部分包含一系列卷積層和最大池化層(max pooling),每次池化操作都會將圖像的空間維度減少一半。
圖片
解碼器
解碼器的任務是通過逐漸恢復圖像的空間分辨率,將編碼器部分提取到的高層次特征映射回原始的圖像分辨率。
解碼器包含反卷積(上采樣)操作,并結合來自編碼器的相應特征層,以實現精細的邊界恢復。
圖片
跳躍連接
跳躍連接是 UNet 的一個關鍵創新點。
每個編碼器層的輸出特征圖與解碼器中對應層的特征圖進行拼接,形成跳躍連接。
這樣可以將編碼器中的局部信息和解碼器中的全局信息進行融合,從而提高分割結果的精度。
圖片
UNet 算法工作流程
- 輸入圖像
- 編碼階段
每個編碼塊包含兩個 3x3 卷積層(帶有 ReLU 激活函數)和一個 2x2 最大池化層,池化層用于下采樣。
經過每個編碼塊后,特征圖的空間尺寸減少一半,但通道數量翻倍。
- 瓶頸層
在網絡的最底部,這部分用來提取最深層次的特征。
- 解碼階段
每個解碼塊包含一個 2x2 轉置卷積(或上采樣操作)和兩個 3x3 卷積層(帶有 ReLU 激活函數)。
- 與編碼路徑不同的是,解碼過程中每次上采樣時,還將相應的編碼層的特征拼接(跳躍連接)到解碼層。
輸出層
最后一層通過 1x1 卷積將輸出通道數映射為類別數,用于生成分割掩碼。
最終輸出的是一個大小與輸入圖像相同的分割圖。
代碼示例
下面是一個使用 UNet 進行圖像分割的簡單示例代碼。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
# UNet 模型定義
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
def conv_block(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.encoder1 = conv_block(1, 64)
self.encoder2 = conv_block(64, 128)
self.encoder3 = conv_block(128, 256)
self.encoder4 = conv_block(256, 512)
self.pool = nn.MaxPool2d(2)
self.bottleneck = conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder4 = conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder3 = conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder2 = conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder1 = conv_block(128, 64)
self.conv_last = nn.Conv2d(64, 1, kernel_size=1)
def forward(self, x):
# Encoder
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool(enc1))
enc3 = self.encoder3(self.pool(enc2))
enc4 = self.encoder4(self.pool(enc3))
# Bottleneck
bottleneck = self.bottleneck(self.pool(enc4))
# Decoder
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return torch.sigmoid(self.conv_last(dec1))
# 創建數據集
class RandomDataset(Dataset):
def __init__(self, num_samples, image_size):
self.num_samples = num_samples
self.image_size = image_size
self.transform = transforms.Compose([transforms.ToTensor()])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
image = torch.randn(1, self.image_size, self.image_size) # 隨機生成圖像
mask = (image > 0).float() # 隨機生成掩碼
return image, mask
# 訓練模型
def train_model():
image_size = 128
batch_size = 8
num_epochs = 10
learning_rate = 1e-3
# 實例化模型、損失函數和優化器
model = UNet()
criterion = nn.BCELoss() # 使用二元交叉熵損失
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
dataset = RandomDataset(num_samples=100, image_size=image_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
for images, masks in dataloader:
outputs = model(images)
loss = criterion(outputs, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
# 測試一個隨機樣本
test_image, test_mask = dataset[0]
model.eval()
with torch.no_grad():
prediction = model(test_image.unsqueeze(0))
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title('Input Image')
plt.imshow(test_image.squeeze().numpy(), cmap='gray')
plt.subplot(1, 3, 2)
plt.title('Ground Truth Mask')
plt.imshow(test_mask.squeeze().numpy(), cmap='gray')
plt.subplot(1, 3, 3)
plt.title('Predicted Mask')
plt.imshow(prediction.squeeze().numpy(), cmap='gray')
plt.show()
train_model()
UNet 的成功源于其有效的特征提取與恢復機制,特別是跳躍連接的設計,使得編碼過程中丟失的細節能夠通過解碼階段恢復。
UNet 在醫學圖像分割等任務上有著廣泛的應用,能夠生成高精度的像素級分割結果。