新的PyTorch圖神經(jīng)網(wǎng)絡(luò)庫,快了14倍:LeCun盛贊,GitHub 2000星
本文經(jīng)AI新媒體量子位(公眾號ID:QbitAI)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請聯(lián)系出處。
“CNN已老,GNN當(dāng)立!”
當(dāng)科學(xué)家們發(fā)現(xiàn),圖神經(jīng)網(wǎng)絡(luò) (GNN) 能搞定傳統(tǒng)CNN處理不了的非歐數(shù)據(jù),從前深度學(xué)習(xí)解不開的許多問題都找到了鑰匙。
如今,有個圖網(wǎng)絡(luò)PyTorch庫,已在GitHub摘下2000多星,還被CNN的爸爸Yann LeCun翻了牌:
它叫PyTorch Geometric,簡稱PyG,聚集了26項(xiàng)圖網(wǎng)絡(luò)研究的代碼實(shí)現(xiàn)。
這個庫還很快,比起前輩DGL圖網(wǎng)絡(luò)庫,PyG***可以達(dá)到它的15倍速度。
應(yīng)有盡有的庫
要跑結(jié)構(gòu)不規(guī)則的數(shù)據(jù),就用PyG吧。不管是圖形 (Graphs),點(diǎn)云 (Point Clouds) 還是流形(Manifolds) 。
這是一個豐盛的庫:許多模型的PyTorch實(shí)現(xiàn),各種有用的轉(zhuǎn)換 (Transforms) ,以及大量常見的benchmark數(shù)據(jù)集,應(yīng)有盡有。
說到實(shí)現(xiàn),包括Kipf等人的圖卷積網(wǎng)絡(luò) (GCN) 和Bengio實(shí)驗(yàn)室的圖注意力網(wǎng)絡(luò) (GAT) 在內(nèi),2017-2019年各大頂會的 (至少) 26項(xiàng)圖網(wǎng)絡(luò)研究,這里都能找到快速實(shí)現(xiàn)。
到底能多快?PyG的兩位作者用英偉達(dá)GTX 1080Ti做了實(shí)驗(yàn)。
對手DGL,也是圖網(wǎng)絡(luò)庫:
在四個數(shù)據(jù)集里,PyG全部比DGL跑得快。最懸殊的一場比賽,是在Cora數(shù)據(jù)集上運(yùn)行GAT模型:跑200個epoch,對手耗時33.4秒,PyG只要2.2秒,相當(dāng)于對方速度的15倍。
每個算法的實(shí)現(xiàn),都支持了CPU計(jì)算和GPU計(jì)算。
食用方法
庫的作者,是兩位德國少年,來自多特蒙德工業(yè)大學(xué)。
△ 其中一位
他們說,有了PyG,做起圖網(wǎng)絡(luò)就像一陣微風(fēng)。
你看,實(shí)現(xiàn)一個邊緣卷積層 (Edge Convolution Layer) 只要這樣而已:
- 1import torch
- 2from torch.nn import Sequential as Seq, Linear as Lin, ReLU
- 3from torch_geometric.nn import MessagePassing
- 4
- 5class EdgeConv(MessagePassing):
- 6 def __init__(self, F_in, F_out):
- 7 super(EdgeConv, self).__init__()
- 8 self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))
- 9
- 10 def forward(self, x, edge_index):
- 11 # x has shape [N, F_in]
- 12 # edge_index has shape [2, E]
- 13 return self.propagate(aggr='max', edge_index=edge_index, x=x) # shape [N, F_out]
- 14
- 15 def message(self, x_i, x_j):
- 16 # x_i has shape [E, F_in]
- 17 # x_j has shape [E, F_in]
- 18 edge_features = torch.cat([x_i, x_j - x_i], dim=1) # shape [E, 2 * F_in]
- 19 return self.mlp(edge_features) # shape [E, F_out]
安裝之前確認(rèn)一下,至少要有PyTorch 1.0.0;再確認(rèn)一下cuda/bin在$PATH里,cuda/include在$CPATH里:
- 1$ python -c "import torch; print(torch.__version__)"
- 2>>> 1.0.0
- 3
- 4$ echo $PATH
- 5>>> /usr/local/cuda/bin:...
- 6
- 7$ echo $CPATH
- 8>>> /usr/local/cuda/include:...
然后,就開始各種pip install吧。
PyG項(xiàng)目傳送門:
https://github.com/rusty1s/pytorch_geometric
PyG主頁傳送門:
https://rusty1s.github.io/pytorch_geometric/build/html/index.html
PyG論文傳送門:
https://arxiv.org/pdf/1903.02428.pdf