深入浅出图神经网络 GCN代码实战
GCN代码实战
书中5.6节的GCN代码实战做的是最经典Cora数据集上的分类,恰当又不恰当的类比Cora之于GNN就相当于MNIST之于机器学习。
有关Cora的介绍网上一搜一大把我就不赘述了,这里说一下Cora这个数据集对应的图是怎么样的。
Cora有2708篇论文,之间有引用关系共5429个,每篇论文作为一个节点,引用关系就是节点之间的边。每篇论文有一个1433维的特征来表示某个词是否在文中出现过,也就是每个节点有1433维的特征。最后这些论文被分为7类。
所以在Cora上训练的目的就是学习节点的特征及其与邻居的关系,根据已知的节点分类对未知分类的节点的类别进行预测。
知道这些应该就OK了,下面来看代码。
数据处理
注释里自己都写了代码引用自PyG我觉得就扫几眼就行了,因为现在常用的数据集两个GNN轮子(DGL和PyG)里都有,现在基本都是直接用,很少自己下原始数据再处理了,所以略过。
GCN层定义
回顾第5章中GCN层的定义:
\]
所以对于一层GCN,就是对输入\(X\),乘一个参数矩阵\(W\),再乘一个算好归一化后的“拉普拉斯矩阵”即可。
来看代码:
class GraphConvolution(nn.Module):
def __init__(self, input_dim, output_dim, use_bias=True):
super(GraphConvolution, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.use_bias = use_bias
self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(output_dim))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.kaiming_uniform_(self.weight)
if self.use_bias:
init.zeros_(self.bias)
def forward(self, adjacency, input_feature):
support = torch.mm(input_feature, self.weight)
output = torch.sparse.mm(adjacency, support)
if self.use_bias:
output += self.bias
return output
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.input_dim) + ' -> ' \
+ str(self.output_dim) + ')'
定义了一层GCN的输入输出维度和偏置,对于GCN层来说,每一层有自己的\(W\),\(X\)是输入给的,\(\tilde L_{sym}\)是数据集算的,所以只需要定义一个weight
矩阵,注意一下维度就行。
传播的时候只要按照公式\(X'=\sigma(\tilde L_{sym}XW)\)进行一下矩阵乘法就好,注意一个trick:\(\tilde L_{sym}\)是稀疏矩阵,所以先矩阵乘法得到\(XW\),再用稀疏矩阵乘法计算\(\tilde L_{sym}XW\)运算效率上更好。
GCN模型定义
知道了GCN层的定义之后堆叠GCN层就可以得到GCN模型了,两层的GCN就可以取得很好的效果(过深的GCN因为过度平滑的问题会导致准确率下降):
class GcnNet(nn.Module):
def __init__(self, input_dim=1433):
super(GcnNet, self).__init__()
self.gcn1 = GraphConvolution(input_dim, 16)
self.gcn2 = GraphConvolution(16, 7)
def forward(self, adjacency, feature):
h = F.relu(self.gcn1(adjacency, feature))
logits = self.gcn2(adjacency, h)
return logits
这里设置隐藏层维度为16,调到32,64,...都是可以的,我自己试的结果来说没有太大的区别。从隐藏层到输出层直接将输出维度设置为分类的维度就可以得到预测分类。
传播的时候相比于每一层的传播只需要加上激活函数,这里选用ReLU
。
训练
定义模型、损失函数(交叉熵)、优化器:
model = GcnNet(input_dim).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DACAY)
具体的训练函数注释已经解释的很清楚:
def train():
loss_history = []
val_acc_history = []
model.train()
train_y = tensor_y[tensor_train_mask]
for epoch in range(EPOCHS):
logits = model(tensor_adjacency, tensor_x) # 前向传播
train_mask_logits = logits[tensor_train_mask] # 只选择训练节点进行监督
loss = criterion(train_mask_logits, train_y) # 计算损失值
optimizer.zero_grad()
loss.backward() # 反向传播计算参数的梯度
optimizer.step() # 使用优化方法进行梯度更新
train_acc, _, _ = test(tensor_train_mask) # 计算当前模型训练集上的准确率
val_acc, _, _ = test(tensor_val_mask) # 计算当前模型在验证集上的准确率
# 记录训练过程中损失值和准确率的变化,用于画图
loss_history.append(loss.item())
val_acc_history.append(val_acc.item())
print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(
epoch, loss.item(), train_acc.item(), val_acc.item()))
return loss_history, val_acc_history
对应的测试函数:
def test(mask):
model.eval()
with torch.no_grad():
logits = model(tensor_adjacency, tensor_x)
test_mask_logits = logits[mask]
predict_y = test_mask_logits.max(1)[1]
accuarcy = torch.eq(predict_y, tensor_y[mask]).float().mean()
return accuarcy, test_mask_logits.cpu().numpy(), tensor_y[mask].cpu().numpy()
注意模型得到的分类不是one-hot的,而是对应不同种类的预测概率,所以要test_mask_logits.max(1)[1]
取概率最高的一个作为模型预测的类别。
这些都写好之后直接运行训练函数即可。有需要还可以对train_loss
和validation_accuracy
进行画图,书上也给出了相应的代码,比较简单不再赘述。
深入浅出图神经网络 GCN代码实战的更多相关文章
- 深入浅出图神经网络 第6章 GCN的性质 读书笔记
第6章 GCN的性质 第5章最后讲到GCN结束的有些匆忙,作为GNN最经典的模型,其有很多性质需要我们去理解. 6.1 GCN与CNN的区别与联系 CNN卷积卷的是矩阵某个区域内的值,图卷积在空域视角 ...
- 图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二)
项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1 欢迎fork欢迎三连!文章篇幅有限, ...
- Scala 深入浅出实战经典 第64讲:Scala中隐式对象代码实战详解
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第63讲:Scala中隐式类代码实战详解
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第52讲:Scala中路径依赖代码实战详解
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第51讲:Scala中链式调用风格的实现代码实战及其在Spark中应用
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第49课 Scala中Variance代码实战(协变)
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第48讲:Scala类型约束代码实战及其在Spark中的应用源码解析
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
- Scala 深入浅出实战经典 第47讲:Scala多重界定代码实战及其在Spark中的应用
王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...
随机推荐
- urllib2连接超时设置
#urllib2设置超时 #获取网页的源码 def getHtml(url,i): if i > 2: return try: req = urllib2.Request(url) time.s ...
- 使用BeautifulSoup高效解析网页,再也不用担心睡不着觉了
BeautifulSoup是一个可以从 HTML 或 XML 文件中提取数据的 Python 库 那需要怎么使用呢? 首先我们要安装一下这个库 1.pip install beautifulsoup4 ...
- NVIDIA DGX SUPERPOD 企业解决方案
NVIDIA DGX SUPERPOD 企业解决方案 实现大规模 AI 创新的捷径 NVIDIA DGX SuperPOD 企业解决方案是业界首个支持任何组织大规模实施 AI 的基础架构解决方案.这一 ...
- Docker基本原理概述
Docker基本原理概述 Docker是一个用于开发,交付和运行应用程序的开放平台.Docker能够将应用程序与基础架构分开,从而可以快速交付软件.借助Docker,可以以与管理应用程序相同的方式来管 ...
- 尚硅谷Java——宋红康笔记【day6-day10】
day6 一.数组的概述 1.数组的理解:数组(Array),是多个相同类型数据按一定顺序排列的集合,并使用一个名字命名,并通过编号的方式对这些数据进行统一管理. 2.数组相关的概念: 数组名 元素 ...
- 一篇文章通俗易懂的让你彻底理解 Java 注解
很多Java程序员,对Java的注解一知半解,更有甚者,有的人可能连注解是什么都不知道 本文我们用最简单的 demo , 最通俗最短的语言,带你了解注解到底是什么? 先来简单回顾一下基础,我们知道,J ...
- 【SQLite】教程09-VBA读取SQLite数据之ODBC,及中文乱码问题
VBA使用ODBC Driver for SQLite读SQLite 如下图有这么一个SQlite数据库,我们要读取它 需要先安装ODBC,可以从这里下载: SQLite 3 ODBC Driver ...
- 【ElasticSearch】给ElasticSearch数据库配置慢查询日志
给ElasticSearch引擎配置慢查询日志,可以实时监控搜索过慢的日志.虽然ElasticSearch以快速搜索而出名,但随着数据量的进一步增大或是服务器的一些性能问题,会有可能出现慢查询的情况. ...
- 不下软件,照样可以完美正确格式化树莓派SD卡!(恢复U盘/SD卡到满容量)
树莓派作用千千万,系统崩溃的理由也数不胜数(不要问我为啥知道),所以系统的重装和sd卡的格式化也在所难免.顺便给大家看一下我今天的成果,我不就是不小心摔了一下我的树莓派...我和sd卡一定是冤家! 捡 ...
- 5.22考试总结(NOIP模拟1)
5.22考试总结(NOIP模拟1) 改题记录 T1 序列 题解 暴力思路很好想,分数也很好想\(QAQ\) (反正我只拿了5pts) 正解的话: 先用欧拉筛把1-n的素数筛出来 void get_Pr ...