环境配置与PyG中图与图数据集的表示和使用

一、引言

PyTorch Geometric (PyG)是面向几何深度学习的PyTorch的扩展库,几何深度学习指的是应用于图和其他不规则、非结构化数据的深度学习。基于PyG库,我们可以轻松地根据数据生成一个图对象,然后很方便的使用它;我们也可以容易地为一个图数据集构造一个数据集类,然后很方便的将它用于神经网络。

通过此节的实践内容,我们将

  1. 首先学习程序运行环境的配置
  2. 接着学习PyG中图数据的表示及其使用,即学习PyG中Data类。
  3. 最后学习PyG中图数据集的表示及其使用,即学习PyG中Dataset类。

二、环境配置

  1. 使用nvidia-smi命令查询显卡驱动是否正确安装

  1. 安装正确版本的pytorch和cudatoolkit,此处安装1.8.1版本的pytorch和11.1版本的cudatoolkit

    1. conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia
    2. 确认是否正确安装,正确的安装应出现下方的结果
    $ python -c "import torch; print(torch.__version__)"
    # 1.8.1
    $ python -c "import torch; print(torch.version.cuda)"
    # 11.1
  2. 安装正确版本的PyG

    pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-geometric

其他版本的安装方法以及安装过程中出现的大部分问题的解决方案可以在Installation of of PyTorch Geometric页面找到。

三、Data类——PyG中图的表示及其使用

Data对象的创建

Data类的官方文档为torch_geometric.data.Data

通过构造函数

Data类的构造函数

class Data(object):

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
r"""
Args:
x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`
edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第0行为尾节点,第1行为头节点,头指向尾
edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签) """
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item

edge_index的每一列定义一条边,其中第一行为边起始节点的索引,第二行为边结束节点的索引。这种表示方法被称为COO格式(coordinate format),通常用于表示稀疏矩阵。PyG不是用稠密矩阵\(\mathbf{A} \in \{ 0, 1 \}^{|\mathcal{V}| \times |\mathcal{V}|}\)来持有邻接矩阵的信息,而是用仅存储邻接矩阵\(\mathbf{A}\)中非\(0\)元素的稀疏矩阵来表示图。

通常,一个图至少包含x, edge_index, edge_attr, y, num_nodes5个属性,当图包含其他属性时,我们可以通过指定额外的参数使Data对象包含其他的属性

graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, num_nodes=num_nodes, other_attr=other_attr)

dict对象为Data对象

我们也可以将一个dict对象转换为一个Data对象

graph_dict = {
'x': x,
'edge_index': edge_index,
'edge_attr': edge_attr,
'y': y,
'num_nodes': num_nodes,
'other_attr': other_attr
}
graph_data = Data.from_dict(graph_dict)

from_dict是一个类方法:

@classmethod
def from_dict(cls, dictionary):
r"""Creates a data object from a python dictionary."""
data = cls()
for key, item in dictionary.items():
data[key] = item return data

注意graph_dict中属性值的类型与大小的要求与Data类的构造函数的要求相同。

Data对象转换成其他类型数据

我们可以将Data对象转换为dict对象:

def to_dict(self):
return {key: item for key, item in self}

或转换为namedtuple

def to_namedtuple(self):
keys = self.keys
DataTuple = collections.namedtuple('DataTuple', keys)
return DataTuple(*[self[key] for key in keys])

获取Data对象属性

x = graph_data['x']

设置Data对象属性

graph_data['x'] = x

获取Data对象包含的属性的关键字

graph_data.keys()

对边排序并移除重复的边

graph_data.coalesce()

Data对象的其他性质

我们通过观察PyG中内置的一个图来查看Data对象的性质:

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
data = dataset[0] # Get the first graph object.
print(data)
print('==============================================================') # 获取图的一些信息
print(f'Number of nodes: {data.num_nodes}') # 节点数量
print(f'Number of edges: {data.num_edges}') # 边数量
print(f'Number of node features: {data.num_node_features}') # 节点属性的维度
print(f'Number of node features: {data.num_features}') # 同样是节点属性的维度
print(f'Number of edge features: {data.num_edge_features}') # 边属性的维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') # 平均节点度
print(f'if edge indices are ordered and do not contain duplicate entries.: {data.is_coalesced()}') # 是否边是有序的同时不含有重复的边
print(f'Number of training nodes: {data.train_mask.sum()}') # 用作训练集的节点
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') # 用作训练集的节点的数量
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') # 此图是否包含孤立的节点
print(f'Contains self-loops: {data.contains_self_loops()}') # 此图是否包含自环的边
print(f'Is undirected: {data.is_undirected()}') # 此图是否是无向图

四、Dataset类——PyG中图数据集的表示及其使用

PyG内置了大量常用的基准数据集,接下来我们以PyG内置的Planetoid数据集为例,来学习PyG中图数据集的表示及使用

Planetoid数据集类的官方文档为torch_geometric.datasets.Planetoid

生成数据集对象并分析数据集

如下方代码所示,在PyG中生成一个数据集是简单直接的。在第一次生成PyG内置的数据集时,程序首先下载原始文件,然后将原始文件处理成包含Data对象的Dataset对象并保存到文件。

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/dataset/Cora', name='Cora')
# Cora() len(dataset)
# 1 dataset.num_classes
# 7 dataset.num_node_features
# 1433

分析数据集中样本

可以看到该数据集只有一个图,包含7个分类任务,节点的属性为1433维度。

data = dataset[0]
# Data(edge_index=[2, 10556], test_mask=[2708],
# train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708]) data.is_undirected()
# True data.train_mask.sum().item()
# 140 data.val_mask.sum().item()
# 500 data.test_mask.sum().item()
# 1000

现在我们看到该数据集包含的唯一的图,有2708个节点,节点特征为1433维,有10556条边,有140个用作训练集的节点,有500个用作验证集的节点,有1000个用作测试集的节点。PyG内置的其他数据集,请小伙伴一一试验,以观察不同数据集的不同。

数据集的使用

假设我们定义好了一个图神经网络模型,其名为Net。在下方的代码中,我们展示了节点分类图数据集在训练过程中的使用。

model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

结语

通过此实践环节,我们学习了程序运行环境的配置PyG中Data对象的生成与使用、以及PyG中Dataset对象的表示和使用。此节内容是图神经网络实践的基础,所涉及的内容是最常用、最基础的,在后面的内容中我们还将学到复杂Data类的构建,和复杂Dataset类的构建。

作业

  • 请通过继承Data类实现一个类专门用于表示“机构-作者-论文”的网络。该网络包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。

参考资料

图神经网络-环境配置与PyG库的更多相关文章

  1. 代理上网环境配置docker私有库

    最后更新时间:2018年12月27日 Docker使用代理上网去 pull 各类 images,需要做如下配置: 创建目录: /etc/systemd/system/docker.service.d ...

  2. 美图WEB开放平台环境配置

    平台环境配置 1.1.设置crossdomain.xml 下载crossdomain.xml文件,把解压出来的crossdomain.xml文件放在您保存图片或图片来源的服务器根目录下,比如: htt ...

  3. 在vc中使用xtremetoolkit界面库-----安装及环境配置

    近期想用一下xtremetoolkitPro界面库.网上的使用教程资源也不多,当中着实遇到了很多的困难,毕竟是首次使用. 首先当然是配置发开环境了: 我使用的是vc6.0+xtremetoolkitP ...

  4. cocos2dx 3.0 学习笔记 引用cocostudio库 的环境配置

    cocostudio创建UI并应用时须要引用cocostudio库,须要额外的环境配置: 之前已经搭配好了基础的开发环境,包含 1) JDK 2) Python 2.7 3) ant 4) visua ...

  5. 乌班图18.04 LTS 版LAMP环境配置记录

    -- 2018.06.07 -- liujunhang lamp 环境包括:Apache服务器.php.Mysql数据库,linux服务器架构在虚拟机中.Tip:在进行环境配置之前最好进行镜像存储.1 ...

  6. 【Linux开发】【Qt开发】配置tslibs触摸屏库环境设置调试对应的设备挂载点

    [Linux开发][Qt开发]配置tslibs触摸屏库环境设置调试对应的设备挂载点 标签(空格分隔): [Linux开发] [Qt开发] 比如: cat /dev/input/mice cat /de ...

  7. PHP配置环境中开启GD库

    下配置好的PHP环境中,GD库不像windows那样可以直接用,而是默认关闭,需要把它打开,去到php.ini文件中 找到php_gd2.dll把分号去掉即可.(注:GD库跟绘制二维码等有关)

  8. Windows Server 2008 R2 IIS7.5下PHP、MySQL快速环境配置【图】

    众所周知,win平台的服务器版本默认是不能运行php的,需要对服务器进行环境配置. 而许多朋友纠结如何配置,在百度上搜索出的教程一大堆,基本步骤复杂,新手配置容易出错. 今天,邹颖峥教大家一种快速配置 ...

  9. cisco路由器 三层交换机简单环境配置实例(图)

    出处:http://www.jb51.NET/softjc/56600.html cisco路由器&三层交换机简单环境配置实例 一.网络拓扑图: 二.配置命令: 1.路由器的配置: inter ...

随机推荐

  1. 论文笔记:(NIPS2018)PointCNN: Convolution On X-Transformed Points

    目录 摘要 一.2D卷积应用在点云上存在的问题 二.解决的方法 2.1 idea 2.2 X-conv算子 2.3 分层卷积 三.实验 3.1分类和分割 3.2消融实验.可视化和模型复杂度 总结 仍存 ...

  2. Input 只能输入正数以及2位小数点

    <input onkeyup="this.value= this.value.match(/\d+(\.\d{0,2})?/) ? this.value.match(/\d+(\.\d ...

  3. Android无障碍宝典-talkback

    http://geek.csdn.net/news/detail/93269 http://geek.csdn.net/news/detail/135867

  4. Python中比较运算符连用的语法规则

    在Python中,比较运用符<.>.<=.>=.== .!=可以连用,但语法规则和其它编程语言不一样 以 == 为例,具体语法规则是: a == b == c == d 等价于 ...

  5. 浙大二院姚克团队发现新的NLRP3炎症小体抑制剂,有望用于治疗炎症疾病

    期刊:Clinical and Translational Medicine 发表时间:2021年7月19日 影响因子:11.492 角膜炎是一种眼科常见疾病,也是我国主要致盲眼病之一,其特征是炎性细 ...

  6. SpringBoot - Bean validation 参数校验

    目录 前言 常见注解 参数校验的应用 依赖 简单的参数校验示例 级联校验 @Validated 与 @Valid 自定义校验注解 前言 后台开发中对参数的校验是不可缺少的一个环节,为了解决如何优雅的对 ...

  7. TCP拥塞控制详解

    1. 拥塞原因与代价 拥塞的代价 当分组的到达速率接近链路容量时,分组经历巨大的排队时延. 发送方必须执行重传以补偿因为缓存溢出而丢弃的分组. 发送方在遇到大时延时进行的不必要重传会引起路由器利用其链 ...

  8. 【javaFX学习】(二) 面板手册

    移至http://blog.csdn.net/qq_37837828/article/details/78732591 更新 找了好几个资料,没找到自己想要的,自己整理下吧,方便以后用的时候挑选,边学 ...

  9. 线程休眠_sleep

    线程休眠_sleep sleep(时间)指定当前线程阻塞的毫秒数: sleep存在异常InterruptedException: sleep时间到达后线程进入就绪状态: sleep可以模拟网络延时,倒 ...

  10. 十进制转十六进制 BASIC-10

    十进制转十六进制 import java.util.Scanner; public class 十进制转十六进制 { /* 十六进制数是在程序设计时经常要使用到的一种整数的表示方式. * 它有0,1, ...