圖像相似度估計 | 結合三元組損失的暹羅網絡
在機器學習領域,確定圖像之間的相似度在各種應用中至關重要,從檢測重復項到面部識別。解決這個問題的一個強大方法是使用暹羅網絡結合三元組損失函數。在本文中,我們將探索如何構建和訓練暹羅網絡以估計圖像相似度,并通過一個來自GitHub倉庫的實際示例進行說明。
什么是暹羅網絡?
暹羅網絡是一種包含兩個或更多相同子網絡的神經網絡架構。這些子網絡旨在為每個輸入生成特征向量,然后可以比較這些向量以估計相似度。關鍵思想是使用相同的網絡處理每個輸入,確保輸出一致且可比較。
這種架構特別適合于檢測重復項、尋找異常和面部識別等任務。在我們將要探索的實現中,網絡設置有三個相同的子網絡。每個網絡處理三張圖像中的一張:錨點圖像、正樣本(與錨點相似)和負樣本(與錨點無關)。
什么是三元組損失?
為了有效地訓練暹羅網絡,我們使用三元組損失函數。這種損失函數鼓勵網絡在特征空間中拉近錨點和正樣本的距離,同時將錨點和負樣本推得更遠。損失函數定義如下:
L(A, P, N) = max(‖f(A) — f(P)‖2 — ‖f(A) — f(N)‖2 + margin, 0)
這里,A是錨點圖像,P是正圖像,N是負圖像。函數f(x)代表網絡生成的embedding,而margin是一個小的正值,有助于確保網絡不會將所有嵌入壓縮到同一點。
設置暹羅網絡
在這次實現中,我們首先加載Totally Looks Like數據集,其中包含我們用來創建訓練網絡的三元組圖像。
1. 數據準備
使用TensorFlow的tf.data API處理數據集以創建圖像三元組。這涉及到設置一個數據管道,其中每個三元組由錨點、正樣本和負樣本圖像組成。通過調整圖像大小到目標形狀并歸一化像素值來預處理圖像。
def preprocess_image(filename):
image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image
def preprocess_triplets(anchor, positive, negative):
return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)
以下是從數據集中生成的三元組示例,每行的前兩張圖像相似(錨點和正樣本),第三張不同(負樣本):
圖1:在數據準備期間生成的三元組。每行的前兩張圖像相似(錨點和正樣本),第三張不同(負樣本)
2.構建 embedding 生成器
我們暹羅網絡的核心是嵌入生成器,它使用在ImageNet上預訓練的ResNet50模型構建。通過凍結ResNet50中的大部分層的權重,并且僅微調最后幾層,我們可以利用遷移學習來減少訓練時間并提高性能。
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
3.構建暹羅網絡
暹羅網絡設置為一次輸入三張圖像(錨點、正樣本和負樣本)。自定義的DistanceLayer計算錨點-正樣本對和錨點-負樣本對之間的距離。然后訓練模型以最小化相似圖像之間的距離,并最大化不相似圖像之間的距離。
class DistanceLayer(layers.Layer):
def call(self, anchor, positive, negative):
ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
return (ap_distance, an_distance)
anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))
distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)
siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)
4.訓練和評估
模型使用自定義訓練循環進行訓練,其中計算三元組損失并用于更新網絡的權重。仔細監控訓練過程,并通過對學習到的嵌入進行檢查來評估模型的性能。
class SiameseModel(Model):
def __init__(self, siamese_network, margin=0.5):
super(SiameseModel, self).__init__()
self.siamese_network = siamese_network
self.margin = margin
self.loss_tracker = metrics.Mean(name="loss")
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self._compute_loss(data)
gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
self.optimizer.apply_gradients(
zip(gradients, self.siamese_network.trainable_weights)
)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
def _compute_loss(self, data):
ap_distance, an_distance = self.siamese_network(data)
loss = ap_distance - an_distance
loss = tf.maximum(loss + self.margin, 0.0)
return loss
5.檢查結果
訓練完成后,我們可以通過比較錨點-正樣本對和錨點-負樣本對的嵌入之間的余弦相似度來評估網絡學習分離相似和不相似圖像的能力。
cosine_similarity = metrics.CosineSimilarity()
positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())
negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())
以下是經過訓練的模型評估的三元組示例。網絡成功識別出圖像之間的相似性和差異:
圖2:經過訓練的暹羅網絡的輸出,其中每行的前兩張圖像被模型識別為相似,第三張為不同
結論
本文展示了使用三元組損失的暹羅網絡如何有效地估計圖像相似度。通過使用預訓練的ResNet50模型并微調其層,我們可以創建一個可以應用于需要相似度估計的各種任務。
完整代碼和解釋,參考:https://github.com/elcaiseri/Siamese-Network