GCN代码实战

书中5.6节的GCN代码实战做的是最经典Cora数据集上的分类,恰当又不恰当的类比Cora之于GNN就相当于MNIST之于机器学习。

有关Cora的介绍网上一搜一大把我就不赘述了,这里说一下Cora这个数据集对应的是怎么样的。

Cora有2708篇论文,之间有引用关系共5429个,每篇论文作为一个节点,引用关系就是节点之间的边。每篇论文有一个1433维的特征来表示某个词是否在文中出现过,也就是每个节点有1433维的特征。最后这些论文被分为7类。

所以在Cora上训练的目的就是学习节点的特征及其与邻居的关系,根据已知的节点分类对未知分类的节点的类别进行预测。

知道这些应该就OK了,下面来看代码。

数据处理

注释里自己都写了代码引用自PyG我觉得就扫几眼就行了,因为现在常用的数据集两个GNN轮子(DGL和PyG)里都有,现在基本都是直接用,很少自己下原始数据再处理了,所以略过。

GCN层定义

回顾第5章中GCN层的定义:

\[X'=\sigma(\tilde L_{sym}XW)
\]

所以对于一层GCN,就是对输入\(X\),乘一个参数矩阵\(W\),再乘一个算好归一化后的“拉普拉斯矩阵”即可。

来看代码:

  1. class GraphConvolution(nn.Module):
  2. def __init__(self, input_dim, output_dim, use_bias=True):
  3. super(GraphConvolution, self).__init__()
  4. self.input_dim = input_dim
  5. self.output_dim = output_dim
  6. self.use_bias = use_bias
  7. self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
  8. if self.use_bias:
  9. self.bias = nn.Parameter(torch.Tensor(output_dim))
  10. else:
  11. self.register_parameter('bias', None)
  12. self.reset_parameters()
  13. def reset_parameters(self):
  14. init.kaiming_uniform_(self.weight)
  15. if self.use_bias:
  16. init.zeros_(self.bias)
  17. def forward(self, adjacency, input_feature):
  18. support = torch.mm(input_feature, self.weight)
  19. output = torch.sparse.mm(adjacency, support)
  20. if self.use_bias:
  21. output += self.bias
  22. return output
  23. def __repr__(self):
  24. return self.__class__.__name__ + ' (' \
  25. + str(self.input_dim) + ' -> ' \
  26. + str(self.output_dim) + ')'

定义了一层GCN的输入输出维度和偏置,对于GCN层来说,每一层有自己的\(W\),\(X\)是输入给的,\(\tilde L_{sym}\)是数据集算的,所以只需要定义一个weight矩阵,注意一下维度就行。

传播的时候只要按照公式\(X'=\sigma(\tilde L_{sym}XW)\)进行一下矩阵乘法就好,注意一个trick:\(\tilde L_{sym}\)是稀疏矩阵,所以先矩阵乘法得到\(XW\),再用稀疏矩阵乘法计算\(\tilde L_{sym}XW\)运算效率上更好。

GCN模型定义

知道了GCN层的定义之后堆叠GCN层就可以得到GCN模型了,两层的GCN就可以取得很好的效果(过深的GCN因为过度平滑的问题会导致准确率下降):

  1. class GcnNet(nn.Module):
  2. def __init__(self, input_dim=1433):
  3. super(GcnNet, self).__init__()
  4. self.gcn1 = GraphConvolution(input_dim, 16)
  5. self.gcn2 = GraphConvolution(16, 7)
  6. def forward(self, adjacency, feature):
  7. h = F.relu(self.gcn1(adjacency, feature))
  8. logits = self.gcn2(adjacency, h)
  9. return logits

这里设置隐藏层维度为16,调到32,64,...都是可以的,我自己试的结果来说没有太大的区别。从隐藏层到输出层直接将输出维度设置为分类的维度就可以得到预测分类。

传播的时候相比于每一层的传播只需要加上激活函数,这里选用ReLU

训练

定义模型、损失函数(交叉熵)、优化器:

  1. model = GcnNet(input_dim).to(DEVICE)
  2. criterion = nn.CrossEntropyLoss().to(DEVICE)
  3. optimizer = optim.Adam(model.parameters(),
  4. lr=LEARNING_RATE,
  5. weight_decay=WEIGHT_DACAY)

具体的训练函数注释已经解释的很清楚:

  1. def train():
  2. loss_history = []
  3. val_acc_history = []
  4. model.train()
  5. train_y = tensor_y[tensor_train_mask]
  6. for epoch in range(EPOCHS):
  7. logits = model(tensor_adjacency, tensor_x) # 前向传播
  8. train_mask_logits = logits[tensor_train_mask] # 只选择训练节点进行监督
  9. loss = criterion(train_mask_logits, train_y) # 计算损失值
  10. optimizer.zero_grad()
  11. loss.backward() # 反向传播计算参数的梯度
  12. optimizer.step() # 使用优化方法进行梯度更新
  13. train_acc, _, _ = test(tensor_train_mask) # 计算当前模型训练集上的准确率
  14. val_acc, _, _ = test(tensor_val_mask) # 计算当前模型在验证集上的准确率
  15. # 记录训练过程中损失值和准确率的变化,用于画图
  16. loss_history.append(loss.item())
  17. val_acc_history.append(val_acc.item())
  18. print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(
  19. epoch, loss.item(), train_acc.item(), val_acc.item()))
  20. return loss_history, val_acc_history

对应的测试函数:

  1. def test(mask):
  2. model.eval()
  3. with torch.no_grad():
  4. logits = model(tensor_adjacency, tensor_x)
  5. test_mask_logits = logits[mask]
  6. predict_y = test_mask_logits.max(1)[1]
  7. accuarcy = torch.eq(predict_y, tensor_y[mask]).float().mean()
  8. return accuarcy, test_mask_logits.cpu().numpy(), tensor_y[mask].cpu().numpy()

注意模型得到的分类不是one-hot的,而是对应不同种类的预测概率,所以要test_mask_logits.max(1)[1]取概率最高的一个作为模型预测的类别。

这些都写好之后直接运行训练函数即可。有需要还可以对train_lossvalidation_accuracy进行画图,书上也给出了相应的代码,比较简单不再赘述。

深入浅出图神经网络 GCN代码实战的更多相关文章

  1. 深入浅出图神经网络 第6章 GCN的性质 读书笔记

    第6章 GCN的性质 第5章最后讲到GCN结束的有些匆忙,作为GNN最经典的模型,其有很多性质需要我们去理解. 6.1 GCN与CNN的区别与联系 CNN卷积卷的是矩阵某个区域内的值,图卷积在空域视角 ...

  2. 图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二)

    项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1 欢迎fork欢迎三连!文章篇幅有限, ...

  3. Scala 深入浅出实战经典 第64讲:Scala中隐式对象代码实战详解

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  4. Scala 深入浅出实战经典 第63讲:Scala中隐式类代码实战详解

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  5. Scala 深入浅出实战经典 第52讲:Scala中路径依赖代码实战详解

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  6. Scala 深入浅出实战经典 第51讲:Scala中链式调用风格的实现代码实战及其在Spark中应用

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  7. Scala 深入浅出实战经典 第49课 Scala中Variance代码实战(协变)

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  8. Scala 深入浅出实战经典 第48讲:Scala类型约束代码实战及其在Spark中的应用源码解析

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  9. Scala 深入浅出实战经典 第47讲:Scala多重界定代码实战及其在Spark中的应用

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

随机推荐

  1. 【odoo】【知识点】视图的继承逻辑

    背景:同一个模块,两组开发人员对同一个模型的form视图进行了二开.在没有指定外部ID的情况下,odoo是如何选择展示展示哪个视图呢? 上干货 odoo在加载视图的时候,首先调用的models.py中 ...

  2. Git使用总结(包含Git Bash和Git GUI的使用)(转自CSDN)

    基本命令 初始化设置 配置本机的用户名和Email地址 $ git config --global user.name "Your Name" $ git config --glo ...

  3. curl 常用操作总结

    前言 curl 是一个强大的命令行工具,支持 HTTP, HTTPS, SCP 等多种协议,本文主要总结一下其常用的功能,方便及时查阅. curl --version curl 7.68.0 (x86 ...

  4. YOLO v4分析

    YOLO v4分析 YOLO v4 的作者共有三位:Alexey Bochkovskiy.Chien-Yao Wang 和 Hong-Yuan Mark Liao.其中一作 Alexey Bochko ...

  5. nvGRAPH API参考分析(二)

    nvGRAPH API参考分析(二) nvGRAPH Code Examples 本文提供了简单的示例. 1. nvGRAPH convert topology example void check( ...

  6. springboot——发送put、delete请求

    在springmvc中我们要发送put和delete请求,需要先配置一个过滤器HiddenHttpMethodFilter,而springboot中,已经帮我们自动配置了,所以我们可以不用配置这个过滤 ...

  7. 新增秒杀功能、优惠券、支付宝、Docker,newbee-mall升级版开源啦!

    最近是非常非常非常忙,一方面是公司的事情比较多,另外⼀点是最近在准备诉讼材料.⾄于诉讼的是谁,⼤家可以去看我之前写的几篇文章,所以本来这周是不打算更新文章的.不过,昨天慕课网的法务联系我的律师了,终于 ...

  8. 【NX二次开发】创建老版的基准平面uf5374

    使用uf5374() 源码: double dP1[3] = { 0.0,0.0,0.0 }; double dP2[3] = { 0.0,1.0,0.0 }; double dP3[3] = { 0 ...

  9. 13:Linux虚拟机网络连接异常

    这两个服务需要启动

  10. 使用VS code编写C++无法实时检测代码的解决办法

    更新:其实微软是有官方文档配置VS code 的C++的.地址是: https://code.visualstudio.com/docs/cpp 更改工作区后就发现不能再使用VS CODE愉快地写C+ ...