简介

Graph Neural Networks 简称 GNN,称为图神经网络,是深度学习中近年来一个比较受关注的领域。近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN 通过对信息的传递,转换和聚合实现特征的提取,类似于传统的 CNN,只是 CNN 只能处理规则的输入,如图片等输入的高、宽和通道数都是固定的,而 GNN 可以处理不规则的输入,如点云等。 可查看【GNN】万字长文带你入门 GCN

而 PyTorch Geometric Library (简称 PyG) 是一个基于 PyTorch 的图神经网络库,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相关论文中的方法实现和常用数据集,并且提供了简单易用的接口来生成图,因此对于复现论文来说也是相当方便。用法大多数和 PyTorch 很相近,因此熟悉 PyTorch 的同学使用这个库可以很快上手。

torch_geometric.data.Data

节点和节点之间的边构成了图。所以在 PyG 中,如果你要构建图,那么需要两个要素:节点和边。PyG 提供了torch_geometric.data.Data (下面简称Data) 用于构建图,包括 5 个属性,每一个属性都不是必须的,可以为空。

  • x: 用于存储每个节点的特征,形状是[num_nodes, num_node_features]
  • edge_index: 用于存储节点之间的边,形状是 [2, num_edges]
  • pos: 存储节点的坐标,形状是[num_nodes, num_dimensions]
  • y: 存储样本标签。如果是每个节点都有标签,那么形状是[num_nodes, *];如果是整张图只有一个标签,那么形状是[1, *]
  • edge_attr: 存储边的特征。形状是[num_edges, num_edge_features]

实际上,Data对象不仅仅限制于这些属性,我们可以通过data.face来扩展Data,以张量保存三维网格中三角形的连接性。

需要注意的的是,在Data里包含了样本的 label,这意味和 PyTorch 稍有不同。在PyTorch中,我们重写Dataset__getitem__(),根据 index 返回对应的样本和 label。在 PyG 中,我们使用的不是这种写法,而是在get()函数中根据 index 返回torch_geometric.data.Data类型的数据,在Data里包含了数据和 label。

下面一个例子是未加权无向图 ( 未加权指边上没有权值 ),包括 3 个节点和 4 条边。

由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每个节点都有自己的特征。上面这个图可以使用`torch_geometric.data.Data`来表示如下:

import torch
from torch_geometric.data import Data
# 由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 节点的特征
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index)

注意edge_index中边的存储方式,有两个list,第 1 个list是边的起始点,第 2 个list是边的目标节点。注意与下面的存储方式的区别。

import torch
from torch_geometric.data import Data edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index.t().contiguous())

这种情况edge_index需要先转置然后使用contiguous()方法。关于contiguous()函数的作用,查看 PyTorch中的contiguous

最后再复习一遍,Data中最基本的 4 个属性是xedge_indexposy,我们一般都需要这 4 个参数。

有了Data,我们可以创建自己的Dataset,读取并返回Data了。

Dataset 与 DataLoader

PyG 的 Dataset继承自torch.utils.data.Dataset,自带了很多图数据集,我们以TUDataset为例,通过以下代码就可以加载数据集,root参数设置数据下载的位置。通过索引可以访问每一个数据。

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]

在一个图中,由edge_indexedge_attr可以决定所有节点的邻接矩阵。PyG 通过创建稀疏的对角邻接矩阵,并在节点维度中连接特征矩阵和 label 矩阵,实现了在 mini-batch 的并行化。PyG 允许在一个 mini-batch 中的每个Data (图) 使用不同数量的节点和边。

# 自定义 Dataset

尽管 PyG 已经包含许多有用的数据集,我们也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供 2 种不同的Dataset

  • InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。
  • Dataset: 使用这个Dataset每次加载一个数据到内存中,比较常用。

我们需要在自定义的Dataset的初始化方法中传入数据存放的路径,然后 PyG 会在这个路径下再划分 2 个文件夹:

  • raw_dir: 存放原始数据的路径,一般是 csv、mat 等格式
  • processed_dir: 存放处理后的数据,一般是 pt 格式 ( 由我们重写process()方法实现)。

在 PyTorch 中,是没有这两个文件夹的。下面来说明一下这两个文件夹在 PyG 中的实际意义和处理逻辑。

torch_geometric.data.Dataset继承自torch.utils.data.Dataset,在初始化方法 __init__()中,会调用_download()方法和_process()方法。

def __init__(self, root=None, transform=None, pre_transform=None,
pre_filter=None):
super(Dataset, self).__init__() if isinstance(root, str):
root = osp.expanduser(osp.normpath(root)) self.root = root
self.transform = transform
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self.__indices__ = None # 执行 self._download() 方法
if 'download' in self.__class__.__dict__.keys():
self._download()
# 执行 self._process() 方法
if 'process' in self.__class__.__dict__.keys():
self._process()

_download()方法如下,首先检查self.raw_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.download()方法下载文件。

def _download(self):
if files_exist(self.raw_paths): # pragma: no cover
return makedirs(self.raw_dir)
self.download()

_process()方法如下,首先在self.processed_dir中有pre_transform,那么判断这个pre_transform和传进来的pre_transform是否一致,如果不一致,那么警告提示用户先删除self.processed_dir文件夹。pre_filter同理。

然后检查self.processed_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.process()生成文件。

def _process(self):
f = osp.join(self.processed_dir, 'pre_transform.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
warnings.warn(
'The `pre_transform` argument differs from the one used in '
'the pre-processed version of this dataset. If you really '
'want to make use of another pre-processing technique, make '
'sure to delete `{}` first.'.format(self.processed_dir))
f = osp.join(self.processed_dir, 'pre_filter.pt')
if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
warnings.warn(
'The `pre_filter` argument differs from the one used in the '
'pre-processed version of this dataset. If you really want to '
'make use of another pre-fitering technique, make sure to '
'delete `{}` first.'.format(self.processed_dir)) if files_exist(self.processed_paths): # pragma: no cover
return print('Processing...') makedirs(self.processed_dir)
self.process() path = osp.join(self.processed_dir, 'pre_transform.pt')
torch.save(__repr__(self.pre_transform), path)
path = osp.join(self.processed_dir, 'pre_filter.pt')
torch.save(__repr__(self.pre_filter), path) print('Done!')

一般来说不用实现downloand()方法

如果你直接把处理好的 pt 文件放在了self.processed_dir中,那么也不用实现process()方法。

在 Pytorch 的dataset中,我们需要实现__getitem__()方法,根据index返回样本和标签。在这里torch_geometric.data.Dataset中,重写了__getitem__()方法,其中调用了get()方法获取数据。

def __getitem__(self, idx):
if isinstance(idx, int):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)

我们需要实现的是get()方法,根据index返回torch_geometric.data.Data类型的数据。

process()方法存在的意义是原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件,这样在get()方法中就可以直接使用torch.load()函数读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据,而不用在get()方法做数据转换操作 (把其他格式的数据转换为 torch_geometric.data.Data类型的数据)。当然我们也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。

DataLoader

通过torch_geometric.data.DataLoader可以方便地使用 mini-batch。

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader:
# 对每一个 mini-batch 进行操作
...

torch_geometric.data.Batch继承自torch_geometric.data.Data,并且多了一个属性:batchbatch是一个列向量,它将每个元素映射到每个 mini-batch 中的相应图:

batch $=\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top}$

我们可以使用它分别为每个图的节点维度计算平均的节点特征:

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True) for data in loader:
data
#data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32]) x = scatter_mean(data.x, data.batch, dim=0)
# x.size(): torch.Size([32, 21])

关于 batching 的流程细节,你可以点击这里查看。关于scatter 方法的说明,你可以查看torch-scatter说明文档

Transforms

transforms在计算机视觉领域是一种很常见的数据增强。PyG 有自己的transforms,输出是Data类型,输出也是Data类型。可以使用torch_geometric.transforms.Compose封装一系列的transforms。我们以 ShapeNet 数据集 (包含 17000 个 point clouds,每个 point 分类为 16 个类别的其中一个) 为例,我们可以使用transforms从 point clouds 生成最近邻图:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

还可以通过transform在一定范围内随机平移每个点,增加坐标上的扰动,做数据增强:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6),
transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

模型训练

这里只是展示一个简单的 GCN 模型构造和训练过程,没有用到DatasetDataLoader

我们将使用一个简单的 GCN 层,并在 Cora 数据集上实验。有关 GCN 的更多内容,请查看这篇博客

我们首先加载数据集:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')

然后定义 2 层的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, data):
x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

然后训练 200 个 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

最后在测试集上验证了模型的准确率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学

图神经网络 PyTorch Geometric 入门教程的更多相关文章

  1. 思维导图软件MindManager新手入门教程

    MindManager是一款创造.管理和交流思想的思维导图软件,其直观清晰的可视化界面和强大的功能可以快速捕捉.组织和共享思维.想法.资源和项目进程等等.MindManager新手入门教程专为新手用户 ...

  2. PyTorch快速入门教程七(RNN做自然语言处理)

    以下内容均来自: https://ptorch.com/news/11.html word embedding也叫做word2vec简单来说就是语料中每一个单词对应的其相应的词向量,目前训练词向量的方 ...

  3. PyTorch 60 分钟入门教程

    PyTorch 60 分钟入门教程:PyTorch 深度学习官方入门中文教程 http://pytorchchina.com/2018/06/25/what-is-pytorch/ PyTorch 6 ...

  4. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  5. 【Zigbee技术入门教程-02】一图读懂ZStack协议栈的基本架构和工作机理

    [Zigbee技术入门教程-02]一图读懂ZStack协议栈的基本架构和工作机理 广东职业技术学院  欧浩源  ohy3686@foxmail.com Z-Stack协议栈是一个基于任务轮询方式的操作 ...

  6. 【Zigbee技术入门教程-02】一图读懂ZStack协议栈的核心思想与工作机理

    [Zigbee技术入门教程-02]一图读懂ZStack协议栈的核心思想与工作机理 广东职业技术学院  欧浩源   Z-Stack协议栈是一个基于任务轮询方式的操作系统,其任务调度和资源分配由操作系统抽 ...

  7. PySide——Python图形化界面入门教程(四)

    PySide——Python图形化界面入门教程(四) ——创建自己的信号槽 ——Creating Your Own Signals and Slots 翻译自:http://pythoncentral ...

  8. PySide——Python图形化界面入门教程(六)

    PySide——Python图形化界面入门教程(六) ——QListView和QStandardItemModel 翻译自:http://pythoncentral.io/pyside-pyqt-tu ...

  9. PySide——Python图形化界面入门教程(五)

    PySide——Python图形化界面入门教程(五) ——QListWidget 翻译自:http://pythoncentral.io/pyside-pyqt-tutorial-the-qlistw ...

随机推荐

  1. 2Ants(独立,一个个判,弹性碰撞,想象)

    AntsDescriptionAn army of ants walk on a horizontal pole of length l cm, each with a constant speed ...

  2. Java Web(2)-jQuery下

    一.jQuery的属性操作 html() 它可以设置和获取起始标签和结束标签中的内容,跟 dom 属性 innerHTML 一样. text() 它可以设置和获取起始标签和结束标签中的文本, 跟 do ...

  3. 5万字长文:Stream和Lambda表达式最佳实践-附PDF下载

    目录 1. Streams简介 1.1 创建Stream 1.2 Streams多线程 1.3 Stream的基本操作 Matching Filtering Mapping FlatMap Reduc ...

  4. variable ans might not have been initialized 报错,以及初始化注意点

    他是说你没有初始化而已,一般只是个warning,如果是在不能跑,那就给他初始化一下. 注意,初始化可不是任意值哈! 就比如如果要算阶乘,你初始化就不能为0. 还有如果是比较大小这类,就不要把初始化统 ...

  5. java图片压缩工具类(指定压缩大小)

    1:先导入依赖 <!--thumbnailator图片处理--> <dependency> <groupId>net.coobird</groupId> ...

  6. JVM 学习笔记记录

    JVM 学习笔记记录 Sun JDK 监控和故障处理工具 名称 主要作用 jps JVM Process Status Tool, 显示指定系统内所有的HotSpot虚拟机进程 jstat JVM S ...

  7. 使用types库修改函数

    import types class ppp: pass p = ppp()#p为ppp类实例对象 def run(self): print("run函数") r = types. ...

  8. Python os.utime() 方法

    概述 os.utime() 方法用于设置指定路径文件最后的修改和访问时间.高佣联盟 www.cgewang.com 在Unix,Windows中有效. 语法 utime()方法语法格式如下: os.u ...

  9. Python os.minor() 方法

    概述 os.minor() 方法用于从原始的设备号中提取设备minor号码 (使用stat中的st_dev或者st_rdev field ).高佣联盟 www.cgewang.com 语法 minor ...

  10. MediaDevices对象

    mediaDevices 是 Navigator对象的一个 只读属性,返回一个 MediaDevices 对象,该对象可提供对相机和麦克风等媒体输入设备的连接访问,也包括屏幕共享. 语法 const ...