終于把圖神經網絡算法搞懂了!!!
今天給大家分享一個強大的算法模型,GNN。
圖神經網絡(GNN)是一類專門處理圖結構數據的深度學習模型。
在傳統的深度學習中,輸入數據通常是結構化的(如圖像、文本、時間序列等),這些數據都可以表示為一個規則的網格或序列。然而,圖數據具有更加復雜的非歐幾里得結構,節點和邊之間可能沒有固定的順序,也可能存在不同的連接模式。
GNN 通過設計一種特定的機制來學習和表示圖結構數據中的節點、邊和全圖的信息。
圖片
圖的基本組成
在討論 GNN 之前,先了解一下圖的基本構成。
- 節點(Node),圖中的基本元素,通常表示圖中實體或對象。
- 邊(Edge),連接節點之間的關系,可能是有向的或無向的。
- 鄰接關系(Adjacency),描述哪些節點之間通過邊相連。鄰接矩陣通常用于表示這種關系。
- 節點特征(Node Feature),每個節點可能有附加的屬性或特征,如社交網絡中用戶的年齡、性別等。
- 邊特征(Edge Feature),邊也可以有特征,例如在交通網絡中,邊可能表示道路的長度或交通流量。
圖片
圖神經網絡的核心思想
GNN 的核心思想是利用圖的拓撲結構,通過節點間的鄰接關系來傳播信息和進行學習。在 GNN 中,節點的表示不僅依賴于其自身的特征,還依賴于其鄰居節點的特征。
圖神經網絡的計算通常包括以下幾個步驟。
- 信息傳遞節點通過與其鄰居節點交換信息來更新自身的表示。這一過程通常通過消息傳遞機制實現,節點會將自己的特征向量傳遞給鄰居節點,鄰居節點再根據自己的特征和接收到的信息來更新自身的特征。
- 聚合每個節點會根據鄰居節點的特征進行聚合操作,常見的聚合操作包括求和、均值、最大值等。這個步驟使得每個節點不僅包含自身的信息,還融合了鄰居的信息。
- 更新聚合后的信息會與當前節點的原始特征一起傳入一個非線性函數(通常是一個神經網絡層),來更新節點的表示。
- 迭代GNN 是一個迭代過程,通常會執行多次消息傳遞和特征更新,每次迭代都會使得節點的表示更加豐富,能夠捕捉到更廣泛的上下文信息。
- 輸出層根據任務需求,最終會從節點特征或者圖特征中提取出有用的信息進行分類、回歸等任務。
GNN 任務類型
節點級任務
節點級任務主要關注圖中單個節點的預測或嵌入。它通常依賴于節點的特征及其鄰居節點的信息。
節點級任務常見的應用包括節點分類、節點嵌入等。
- 節點分類:預測每個節點的類別。
- 節點嵌入:學習每個節點的低維表示,通常用于下游任務(如聚類或分類)。
- 節點回歸:預測節點的連續值。
示例代碼:節點分類任務
假設我們有一個社交網絡圖,任務是預測每個用戶的興趣類別(例如,體育、音樂、科技等)。
我們使用 PyTorch Geometric 框架實現一個簡單的圖卷積網絡(GCN)。
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import GCNConv
# GCN模型定義
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GCN, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 第一次圖卷積
x = self.conv1(x, edge_index)
x = torch.relu(x)
# 第二次圖卷積
x = self.conv2(x, edge_index)
return x
# 假設我們有一個圖,包含節點特征和邊的連接關系
# 節點特征: x, 鄰接矩陣: edge_index
x = torch.randn(100, 16) # 100個節點,16維特征
edge_index = torch.randint(0, 100, (2, 500)) # 500條邊
# 目標標簽:節點的類別(假設有10個類別)
y = torch.randint(0, 10, (100,))
# 創建GCN模型
model = GCN(in_channels=16, hidden_channels=32, out_channels=10)
# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 訓練過程
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(x, edge_index) # 獲取節點的分類輸出
loss = criterion(out, y) # 計算損失
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
節點級任務的應用場景
- 社交網絡分析:預測社交網絡中每個用戶的興趣標簽。
- 生物信息學:預測基因、蛋白質的功能類別。
- 推薦系統:預測用戶或物品的類別或偏好。
邊級任務
邊級任務關注圖中節點間的關系。
邊級任務常見的應用包括鏈接預測、邊分類等。
- 鏈接預測:預測兩個節點之間是否存在邊,或預測未觀察到的潛在邊。
- 邊分類:對圖中的邊進行分類任務,如判斷兩個節點之間的關系類型。
- 邊回歸:預測邊的連續值,如邊的權重或相似度。
示例代碼:鏈接預測任務
在鏈接預測任務中,我們預測圖中節點對是否存在邊。
通過GNN學習到的節點表示,可以計算節點對之間的相似度,進而預測鏈接。
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
# GCN模型定義(用于鏈接預測)
class GCNLinkPrediction(nn.Module):
def __init__(self, in_channels, hidden_channels):
super(GCNLinkPrediction, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = torch.relu(x)
x = self.conv2(x, edge_index)
return x
# 假設我們有一個圖,包含節點特征和邊的連接關系
x = torch.randn(100, 16) # 100個節點,16維特征
edge_index = torch.randint(0, 100, (2, 500)) # 500條邊
# 創建GCN模型
model = GCNLinkPrediction(in_channels=16, hidden_channels=32)
# 定義損失函數和優化器
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 訓練過程
model.train()
for epoch in range(200):
optimizer.zero_grad()
# 前向傳播,得到節點嵌入表示
out = model(x, edge_index)
# 負采樣生成不存在的邊
neg_edge_index = negative_sampling(edge_index, num_nodes=100, num_neg_samples=edge_index.size(1))
# 獲取真實邊和負邊
pos_out = out[edge_index[0]] * out[edge_index[1]]
neg_out = out[neg_edge_index[0]] * out[neg_edge_index[1]]
# 計算損失
pos_loss = torch.sigmoid(pos_out).sum()
neg_loss = torch.sigmoid(neg_out).sum()
loss = -(pos_loss - neg_loss)
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
邊級任務的應用場景
- 社交網絡分析:預測用戶之間是否會建立新的聯系(如好友推薦)。
- 推薦系統:預測用戶與物品之間的潛在關系(例如,是否購買)。
- 知識圖譜:預測實體之間的關系(如“巴黎”和“法國”的“首都”關系)。
圖級任務
圖級任務關注整個圖的預測或表示。
任務的目標是將整個圖映射到一個類別或一個值,常見的任務包括圖分類、圖回歸等。
- 圖分類:對整個圖進行分類,常用于生物分子分類、文檔分類等。
- 圖回歸:預測整個圖的連續值,如預測圖的某種特性(例如分子的毒性)。
GNN的類型
不同的 GNN 變種在消息傳遞、聚合和更新機制上有所不同。
以下是一些常見的GNN模型:
- GCNGCN 是最經典的圖卷積網絡,它借鑒了卷積神經網絡的思想,通過對鄰居節點的特征進行加權平均來更新節點表示。GCN 使用了圖的鄰接矩陣來定義節點間的信息傳播規則。
- GATGAT 引入了注意力機制,在信息傳遞的過程中,給不同的鄰居節點分配不同的權重(即鄰接節點的影響力不同)。這種方式使得 GAT 能夠更靈活地處理圖中節點的異質性。
- GraphSAGEGraphSAGE 通過對每個節點的鄰居進行采樣來減少計算開銷,而不是直接使用全部鄰居節點。
GNN的應用場景
圖神經網絡在很多領域得到了廣泛應用
- 社交網絡分析在社交網絡中,節點表示人或社交媒體賬戶,邊表示他們之間的互動關系。GNN可以用來進行用戶推薦、社交圈分析、輿情分析等任務。
- 化學分子建模在化學中,分子結構可以用圖表示,其中節點代表原子,邊代表原子之間的化學鍵。GNN可以用來預測分子的性質、藥物設計等。
- 知識圖譜知識圖譜是包含實體和關系的大型圖結構,GNN 可以用于關系預測、實體鏈接等任務。
- 推薦系統在推薦系統中,用戶和物品可以構成圖結構,GNN 可以用于用戶偏好預測、物品推薦等。
- 自然語言處理在文本中,詞語之間的關系可以通過圖表示,GNN 可以用來進行句子理解、語義分析等任務。