图神经网络 PyTorch Geometric 入门教程
简介
Graph Neural Networks 简称 GNN,称为图神经网络,是深度学习中近年来一个比较受关注的领域。近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN 通过对信息的传递,转换和聚合实现特征的提取,类似于传统的 CNN,只是 CNN 只能处理规则的输入,如图片等输入的高、宽和通道数都是固定的,而 GNN 可以处理不规则的输入,如点云等。 可查看【GNN】万字长文带你入门 GCN。
而 PyTorch Geometric Library (简称 PyG) 是一个基于 PyTorch 的图神经网络库,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相关论文中的方法实现和常用数据集,并且提供了简单易用的接口来生成图,因此对于复现论文来说也是相当方便。用法大多数和 PyTorch 很相近,因此熟悉 PyTorch 的同学使用这个库可以很快上手。
torch_geometric.data.Data
节点和节点之间的边构成了图。所以在 PyG 中,如果你要构建图,那么需要两个要素:节点和边。PyG 提供了torch_geometric.data.Data
(下面简称Data
) 用于构建图,包括 5 个属性,每一个属性都不是必须的,可以为空。
- x: 用于存储每个节点的特征,形状是
[num_nodes, num_node_features]
。 - edge_index: 用于存储节点之间的边,形状是
[2, num_edges]
。 - pos: 存储节点的坐标,形状是
[num_nodes, num_dimensions]
。 - y: 存储样本标签。如果是每个节点都有标签,那么形状是
[num_nodes, *]
;如果是整张图只有一个标签,那么形状是[1, *]
。 - edge_attr: 存储边的特征。形状是
[num_edges, num_edge_features]
。
实际上,Data
对象不仅仅限制于这些属性,我们可以通过data.face
来扩展Data
,以张量保存三维网格中三角形的连接性。
需要注意的的是,在Data
里包含了样本的 label,这意味和 PyTorch 稍有不同。在PyTorch
中,我们重写Dataset
的__getitem__()
,根据 index 返回对应的样本和 label。在 PyG 中,我们使用的不是这种写法,而是在get()
函数中根据 index 返回torch_geometric.data.Data
类型的数据,在Data
里包含了数据和 label。
下面一个例子是未加权无向图 ( 未加权指边上没有权值 ),包括 3 个节点和 4 条边。

由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每个节点都有自己的特征。上面这个图可以使用`torch_geometric.data.Data`来表示如下:
import torch
from torch_geometric.data import Data
# 由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 节点的特征
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
注意edge_index
中边的存储方式,有两个list
,第 1 个list
是边的起始点,第 2 个list
是边的目标节点。注意与下面的存储方式的区别。
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
这种情况edge_index
需要先转置然后使用contiguous()
方法。关于contiguous()
函数的作用,查看 PyTorch中的contiguous。
最后再复习一遍,Data
中最基本的 4 个属性是x
、edge_index
、pos
、y
,我们一般都需要这 4 个参数。
有了Data
,我们可以创建自己的Dataset
,读取并返回Data
了。
Dataset 与 DataLoader
PyG 的 Dataset
继承自torch.utils.data.Dataset
,自带了很多图数据集,我们以TUDataset
为例,通过以下代码就可以加载数据集,root
参数设置数据下载的位置。通过索引可以访问每一个数据。
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]
在一个图中,由edge_index
和edge_attr
可以决定所有节点的邻接矩阵。PyG 通过创建稀疏的对角邻接矩阵,并在节点维度中连接特征矩阵和 label 矩阵,实现了在 mini-batch 的并行化。PyG 允许在一个 mini-batch 中的每个Data
(图) 使用不同数量的节点和边。

# 自定义 Dataset
尽管 PyG 已经包含许多有用的数据集,我们也可以通过继承torch_geometric.data.Dataset
使用自己的数据集。提供 2 种不同的Dataset
:
- InMemoryDataset:使用这个
Dataset
会一次性把数据全部加载到内存中。 - Dataset: 使用这个
Dataset
每次加载一个数据到内存中,比较常用。
我们需要在自定义的Dataset
的初始化方法中传入数据存放的路径,然后 PyG 会在这个路径下再划分 2 个文件夹:
raw_dir
: 存放原始数据的路径,一般是 csv、mat 等格式processed_dir
: 存放处理后的数据,一般是 pt 格式 ( 由我们重写process()
方法实现)。
在 PyTorch 中,是没有这两个文件夹的。下面来说明一下这两个文件夹在 PyG 中的实际意义和处理逻辑。
torch_geometric.data.Dataset
继承自torch.utils.data.Dataset
,在初始化方法 __init__()
中,会调用_download()
方法和_process()
方法。
def __init__(self, root=None, transform=None, pre_transform=None,
pre_filter=None):
super(Dataset, self).__init__()
if isinstance(root, str):
root = osp.expanduser(osp.normpath(root))
self.root = root
self.transform = transform
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self.__indices__ = None
# 执行 self._download() 方法
if 'download' in self.__class__.__dict__.keys():
self._download()
# 执行 self._process() 方法
if 'process' in self.__class__.__dict__.keys():
self._process()
_download()
方法如下,首先检查self.raw_paths
列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.download()
方法下载文件。
def _download(self):
if files_exist(self.raw_paths): # pragma: no cover
return
makedirs(self.raw_dir)
self.download()
_process()
方法如下,首先在self.processed_dir
中有pre_transform
,那么判断这个pre_transform
和传进来的pre_transform
是否一致,如果不一致,那么警告提示用户先删除self.processed_dir
文件夹。pre_filter
同理。
然后检查self.processed_paths
列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.process()
生成文件。
def _process(self):
f = osp.join(self.processed_dir, 'pre_transform.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
warnings.warn(
'The `pre_transform` argument differs from the one used in '
'the pre-processed version of this dataset. If you really '
'want to make use of another pre-processing technique, make '
'sure to delete `{}` first.'.format(self.processed_dir))
f = osp.join(self.processed_dir, 'pre_filter.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
warnings.warn(
'The `pre_filter` argument differs from the one used in the '
'pre-processed version of this dataset. If you really want to '
'make use of another pre-fitering technique, make sure to '
'delete `{}` first.'.format(self.processed_dir))
if files_exist(self.processed_paths): # pragma: no cover
return
print('Processing...')
makedirs(self.processed_dir)
self.process()
path = osp.join(self.processed_dir, 'pre_transform.pt')
torch.save(__repr__(self.pre_transform), path)
path = osp.join(self.processed_dir, 'pre_filter.pt')
torch.save(__repr__(self.pre_filter), path)
print('Done!')
一般来说不用实现downloand()
方法。
如果你直接把处理好的 pt 文件放在了self.processed_dir
中,那么也不用实现process()
方法。
在 Pytorch 的dataset
中,我们需要实现__getitem__()
方法,根据index
返回样本和标签。在这里torch_geometric.data.Dataset
中,重写了__getitem__()
方法,其中调用了get()
方法获取数据。
def __getitem__(self, idx):
if isinstance(idx, int):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)
我们需要实现的是get()
方法,根据index
返回torch_geometric.data.Data
类型的数据。
process()
方法存在的意义是原始的格式可能是 csv 或者 mat,在process()
函数里可以转化为 pt 格式的文件,这样在get()
方法中就可以直接使用torch.load()
函数读取 pt 格式的文件,返回的是torch_geometric.data.Data
类型的数据,而不用在get()
方法做数据转换操作 (把其他格式的数据转换为 torch_geometric.data.Data
类型的数据)。当然我们也可以提前把数据转换为 torch_geometric.data.Data
类型,使用 pt 格式保存在self.processed_dir
中。
DataLoader
通过torch_geometric.data.DataLoader
可以方便地使用 mini-batch。
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
# 对每一个 mini-batch 进行操作
...
torch_geometric.data.Batch
继承自torch_geometric.data.Data
,并且多了一个属性:batch
。batch
是一个列向量,它将每个元素映射到每个 mini-batch 中的相应图:
batch $=\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top}$
我们可以使用它分别为每个图的节点维度计算平均的节点特征:
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for data in loader:
data
#data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
x = scatter_mean(data.x, data.batch, dim=0)
# x.size(): torch.Size([32, 21])
关于 batching 的流程细节,你可以点击这里查看。关于scatter
方法的说明,你可以查看torch-scatter
说明文档。
Transforms
transforms
在计算机视觉领域是一种很常见的数据增强。PyG 有自己的transforms
,输出是Data
类型,输出也是Data
类型。可以使用torch_geometric.transforms.Compose
封装一系列的transforms
。我们以 ShapeNet 数据集 (包含 17000 个 point clouds,每个 point 分类为 16 个类别的其中一个) 为例,我们可以使用transforms
从 point clouds 生成最近邻图:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
还可以通过transform
在一定范围内随机平移每个点,增加坐标上的扰动,做数据增强:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6),
transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
模型训练
这里只是展示一个简单的 GCN 模型构造和训练过程,没有用到Dataset
和DataLoader
。
我们将使用一个简单的 GCN 层,并在 Cora 数据集上实验。有关 GCN 的更多内容,请查看这篇博客
。
我们首先加载数据集:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
然后定义 2 层的 GCN:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
然后训练 200 个 epochs:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
最后在测试集上验证了模型的准确率:
model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学。

图神经网络 PyTorch Geometric 入门教程的更多相关文章
- 思维导图软件MindManager新手入门教程
MindManager是一款创造.管理和交流思想的思维导图软件,其直观清晰的可视化界面和强大的功能可以快速捕捉.组织和共享思维.想法.资源和项目进程等等.MindManager新手入门教程专为新手用户 ...
- PyTorch快速入门教程七(RNN做自然语言处理)
以下内容均来自: https://ptorch.com/news/11.html word embedding也叫做word2vec简单来说就是语料中每一个单词对应的其相应的词向量,目前训练词向量的方 ...
- PyTorch 60 分钟入门教程
PyTorch 60 分钟入门教程:PyTorch 深度学习官方入门中文教程 http://pytorchchina.com/2018/06/25/what-is-pytorch/ PyTorch 6 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- 【Zigbee技术入门教程-02】一图读懂ZStack协议栈的基本架构和工作机理
[Zigbee技术入门教程-02]一图读懂ZStack协议栈的基本架构和工作机理 广东职业技术学院 欧浩源 ohy3686@foxmail.com Z-Stack协议栈是一个基于任务轮询方式的操作 ...
- 【Zigbee技术入门教程-02】一图读懂ZStack协议栈的核心思想与工作机理
[Zigbee技术入门教程-02]一图读懂ZStack协议栈的核心思想与工作机理 广东职业技术学院 欧浩源 Z-Stack协议栈是一个基于任务轮询方式的操作系统,其任务调度和资源分配由操作系统抽 ...
- PySide——Python图形化界面入门教程(四)
PySide——Python图形化界面入门教程(四) ——创建自己的信号槽 ——Creating Your Own Signals and Slots 翻译自:http://pythoncentral ...
- PySide——Python图形化界面入门教程(六)
PySide——Python图形化界面入门教程(六) ——QListView和QStandardItemModel 翻译自:http://pythoncentral.io/pyside-pyqt-tu ...
- PySide——Python图形化界面入门教程(五)
PySide——Python图形化界面入门教程(五) ——QListWidget 翻译自:http://pythoncentral.io/pyside-pyqt-tutorial-the-qlistw ...
随机推荐
- 2Ants(独立,一个个判,弹性碰撞,想象)
AntsDescriptionAn army of ants walk on a horizontal pole of length l cm, each with a constant speed ...
- Java Web(2)-jQuery下
一.jQuery的属性操作 html() 它可以设置和获取起始标签和结束标签中的内容,跟 dom 属性 innerHTML 一样. text() 它可以设置和获取起始标签和结束标签中的文本, 跟 do ...
- 5万字长文:Stream和Lambda表达式最佳实践-附PDF下载
目录 1. Streams简介 1.1 创建Stream 1.2 Streams多线程 1.3 Stream的基本操作 Matching Filtering Mapping FlatMap Reduc ...
- variable ans might not have been initialized 报错,以及初始化注意点
他是说你没有初始化而已,一般只是个warning,如果是在不能跑,那就给他初始化一下. 注意,初始化可不是任意值哈! 就比如如果要算阶乘,你初始化就不能为0. 还有如果是比较大小这类,就不要把初始化统 ...
- java图片压缩工具类(指定压缩大小)
1:先导入依赖 <!--thumbnailator图片处理--> <dependency> <groupId>net.coobird</groupId> ...
- JVM 学习笔记记录
JVM 学习笔记记录 Sun JDK 监控和故障处理工具 名称 主要作用 jps JVM Process Status Tool, 显示指定系统内所有的HotSpot虚拟机进程 jstat JVM S ...
- 使用types库修改函数
import types class ppp: pass p = ppp()#p为ppp类实例对象 def run(self): print("run函数") r = types. ...
- Python os.utime() 方法
概述 os.utime() 方法用于设置指定路径文件最后的修改和访问时间.高佣联盟 www.cgewang.com 在Unix,Windows中有效. 语法 utime()方法语法格式如下: os.u ...
- Python os.minor() 方法
概述 os.minor() 方法用于从原始的设备号中提取设备minor号码 (使用stat中的st_dev或者st_rdev field ).高佣联盟 www.cgewang.com 语法 minor ...
- MediaDevices对象
mediaDevices 是 Navigator对象的一个 只读属性,返回一个 MediaDevices 对象,该对象可提供对相机和麦克风等媒体输入设备的连接访问,也包括屏幕共享. 语法 const ...