1.ERNIESage运行实例介绍(1.8x版本)

本项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?contributionType=1

本项目主要是为了直接提供一个可以运行ERNIESage模型的环境,

https://github.com/PaddlePaddle/PGL/blob/develop/examples/erniesage/README.md

在很多工业应用中,往往出现如下图所示的一种特殊的图:Text Graph。顾名思义,图的节点属性由文本构成,而边的构建提供了结构信息。如搜索场景下的Text Graph,节点可由搜索词、网页标题、网页正文来表达,用户反馈和超链信息则可构成边关系。

ERNIESage 由PGL团队提出,是ERNIE SAmple aggreGatE的简称,该模型可以同时建模文本语义与图结构信息,有效提升 Text Graph 的应用效果。其中 ERNIE 是百度推出的基于知识增强的持续学习语义理解框架。

ERNIESage 是 ERNIE 与 GraphSAGE 碰撞的结果,是 ERNIE SAmple aggreGatE 的简称,它的结构如下图所示,主要思想是通过 ERNIE 作为聚合函数(Aggregators),建模自身节点和邻居节点的语义与结构关系。ERNIESage 对于文本的建模是构建在邻居聚合的阶段,中心节点文本会与所有邻居节点文本进行拼接;然后通过预训练的 ERNIE 模型进行消息汇聚,捕捉中心节点以及邻居节点之间的相互关系;最后使用 ERNIESage 搭配独特的邻居互相看不见的 Attention Mask 和独立的 Position Embedding 体系,就可以轻松构建 TextGraph 中句子之间以及词之间的关系。

使用ID特征的GraphSAGE只能够建模图的结构信息,而单独的ERNIE只能处理文本信息。通过PGL搭建的图与文本的桥梁,ERNIESage能够很简单的把GraphSAGE以及ERNIE的优点结合一起。以下面TextGraph的场景,ERNIESage的效果能够比单独的ERNIE以及GraphSAGE模型都要好。

ERNIESage可以很轻松地在PGL中的消息传递范式中进行实现,目前PGL在github上提供了3个版本的ERNIESage模型:

  • ERNIESage v1: ERNIE 作用于text graph节点上;
  • ERNIESage v2: ERNIE 作用在text graph的边上;
  • ERNIESage v3: ERNIE 作用于一阶邻居及起边上;

主要会针对ERNIESageV1和ERNIESageV2版本进行一个介绍。

1.1算法实现

可能有同学对于整个项目代码文件都不太了解,因此这里会做一个比较简单的讲解。

核心部分包含:

  • 数据集部分
  1. data.txt - 简单的输入文件,格式为每行query \t answer,可作简单的运行实例使用。
  • 模型文件和配置部分
  1. ernie_config.json - ERNIE模型的配置文件。
  2. vocab.txt - ERNIE模型所使用的词表。
  3. ernie_base_ckpt/ - ERNIE模型参数。
  4. config/ - ERNIESage模型的配置文件,包含了三个版本的配置文件。
  • 代码部分
  1. local_run.sh - 入口文件,通过该入口可完成预处理、训练、infer三个步骤。
  2. preprocessing文件夹 - 包含dump_graph.py, tokenization.py。在预处理部分,我们首先需要进行建图,将输入的文件构建成一张图。由于我们所研究的是Text Graph,因此节点都是文本,我们将文本表示为该节点对应的node feature(节点特征),处理文本的时候需要进行切字,再映射为对应的token id。
  3. dataset/ - 该文件夹包含了数据ready的代码,以便于我们在训练的时候将训练数据以batch的方式读入。
  4. models/ - 包含了ERNIESage模型核心代码。
  5. train.py - 模型训练入口文件。
  6. learner.py - 分布式训练代码,通过train.py调用。
  7. infer.py - infer代码,用于infer出节点对应的embedding。
  • 评价部分
  1. build_dev.py - 用于将我们的验证集修改为需要的格式。
  2. mrr.py - 计算MRR值。

要在这个项目中运行模型其实很简单,只要运行下方的入口命令就ok啦!但是,需要注意的是,由于ERNIESage模型比较大,所以如果AIStudio中的CPU版本运行模型容易出问题。因此,在运行部署环境时,建议选择GPU的环境。

另外,如果提示出现了GPU空间不足等问题,我们可以通过调小对应yaml文件中的batch_size来调整,也可以修改ERNIE模型的配置文件ernie_config.json,将num_hidden_layers设小一些。在这里,我仅提供了ERNIESageV2版本的gpu运行过程,如果同学们想运行其他版本的模型,可以根据需要修改下方的命令。

运行完毕后,会产生较多的文件,这里进行简单的解释。

  1. workdir/ - 这个文件夹主要会存储和图相关的数据信息。
  2. output/ - 主要的输出文件夹,包含了以下内容:(1)模型文件,根据config文件中的save_per_step可调整保存模型的频率,如果设置得比较大则可能训练过程中不会保存模型; (2)last文件夹,保存了停止训练时的模型参数,在infer阶段我们会使用这部分模型参数;(3)part-0文件,infer之后的输入文件中所有节点的Embedding输出。

为了可以比较清楚地知道Embedding的效果,我们直接通过MRR简单判断一下data.txt计算出来的Embedding结果,此处将data.txt同时作为训练集和验证集。

1.2 核心模型代码讲解

首先,我们可以通过查看models/model_factory.py来判断在本项目有多少种ERNIESage模型。

from models.base import BaseGNNModel
from models.ernie import ErnieModel
from models.erniesage_v1 import ErnieSageModelV1
from models.erniesage_v2 import ErnieSageModelV2
from models.erniesage_v3 import ErnieSageModelV3 class Model(object):
@classmethod
def factory(cls, config):
name = config.model_type
if name == "BaseGNNModel":
return BaseGNNModel(config)
if name == "ErnieModel":
return ErnieModel(config)
if name == "ErnieSageModelV1":
return ErnieSageModelV1(config)
if name == "ErnieSageModelV2":
return ErnieSageModelV2(config)
if name == "ErnieSageModelV3":
return ErnieSageModelV3(config)
else:
raise ValueError

可以看到一共有ERNIESage模型一共有3个版本,另外我们也提供了基本的GNN模型和ERNIE模型,感兴趣的同学可以自行查阅。

接下来,我主要会针对ERNIESageV1和ERNIESageV2这两个版本的模型进行关键部分的讲解,主要的不同其实就是消息传递机制(Message Passing)部分的不同。

1.2.1 ERNIESageV1关键代码

# ERNIESageV1的Message Passing代码
# 查找路径:erniesage_v1.py(__call__中的self.gnn_layers) -> base.py(BaseNet类中的gnn_layers方法) -> message_passing.py # erniesage_v1.py
def __call__(self, graph_wrappers):
inputs = self.build_inputs()
feature = self.build_embedding(graph_wrappers, inputs[-1]) # 将节点的文本信息利用ERNIE模型建模,生成对应的Embedding作为feature
features = self.gnn_layers(graph_wrappers, feature) # GNN模型的主要不同,消息传递机制入口
outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
outputs.append(src_real_index)
return inputs, outputs # base.py -> BaseNet
def gnn_layers(self, graph_wrappers, feature):
features = [feature]
initializer = None
fc_lr = self.config.lr / 0.001
for i in range(self.config.num_layers):
if i == self.config.num_layers - 1:
act = None
else:
act = "leaky_relu"
feature = get_layer(
self.config.layer_type, # 对于ERNIESageV1, 其layer_type="graphsage_sum",可以到config文件夹中查看
graph_wrappers[i],
feature,
self.config.hidden_size,
act,
initializer,
learning_rate=fc_lr,
name="%s_%s" % (self.config.layer_type, i))
features.append(feature)
return features # message_passing.py
def graphsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
"""doc"""
msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv) # Recv
self_feature = feature
self_feature = fluid.layers.fc(self_feature,
hidden_size,
act=act,
param_attr=fluid.ParamAttr(name=name + "_l.w_0", initializer=initializer,
learning_rate=learning_rate),
bias_attr=name+"_l.b_0"
)
neigh_feature = fluid.layers.fc(neigh_feature,
hidden_size,
act=act,
param_attr=fluid.ParamAttr(name=name + "_r.w_0", initializer=initializer,
learning_rate=learning_rate),
bias_attr=name+"_r.b_0"
)
output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
output = fluid.layers.l2_normalize(output, axis=1)
return output

通过上述代码片段可以看到,关键的消息传递机制代码就是graphsage_sum函数,其中send、recv部分如下。

def copy_send(src_feat, dst_feat, edge_feat):
"""doc"""
return src_feat["h"] msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv) # Recv

通过代码可以看到,ERNIESageV1版本,其主要是针对节点邻居,直接将当前节点的邻居节点特征求和。再看到graphsage_sum函数中,将邻居节点特征进行求和后,得到了neigh_feature。随后,我们将节点本身的特征self_feature和邻居聚合特征neigh_feature通过fc层后,直接concat起来,从而得到了当前gnn layer层的feature输出。

1.2.2ERNIESageV2关键代码

ERNIESageV2的消息传递机制代码主要在erniesage_v2.py和message_passing.py,相对ERNIESageV1来说,代码会相对长了一些。

为了使得大家对下面有关ERNIE模型的部分能够有所了解,这里先贴出ERNIE的主模型框架图。

具体的代码解释可以直接看注释。

# ERNIESageV2的Message Passing代码

# 下面的函数都在erniesage_v2.py的ERNIESageV2类中
# ERNIESageV2的调用函数
def __call__(self, graph_wrappers):
inputs = self.build_inputs()
feature = inputs[-1]
features = self.gnn_layers(graph_wrappers, feature)
outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
outputs.append(src_real_index)
return inputs, outputs # 进入self.gnn_layers函数
def gnn_layers(self, graph_wrappers, feature):
features = [feature] initializer = None
fc_lr = self.config.lr / 0.001 for i in range(self.config.num_layers):
if i == self.config.num_layers - 1:
act = None
else:
act = "leaky_relu" feature = self.gnn_layer(
graph_wrappers[i],
feature,
self.config.hidden_size,
act,
initializer,
learning_rate=fc_lr,
name="%s_%s" % ("erniesage_v2", i))
features.append(feature)
return features
接下来会进入ERNIESageV2主要的代码部分。

可以看到,在ernie_send函数用于将我们的邻居信息发送到当前节点。在ERNIESageV1中,我们在Send阶段对邻居节点通过ERNIE模型得到Embedding后,再直接求和,实际上当前节点和邻居节点之间的文本信息在消息传递过程中是没有直接交互的,直到最后才**concat**起来;而ERNIESageV2中,在Send阶段,源节点和目标节点的信息会直接concat起来,通过ERNIE模型得到一个统一的Embedding,这样就得到了源节点和目标节点的一个信息交互过程,这个部分可以查看下面的ernie_send函数。

gnn_layer函数中包含了三个函数:
1. ernie_send: 将src和dst节点对应文本concat后,过Ernie后得到需要的msg,更加具体的解释可以看下方代码注释。
2. build_position_ids: 主要是为了创建位置ID,提供给Ernie,从而可以产生position embeddings。
3. erniesage_v2_aggregator: gnn_layer的入口函数,包含了消息传递机制,以及聚合后的消息feature处理过程。
# 进入self.gnn_layer函数
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
def build_position_ids(src_ids, dst_ids): # 此函数用于创建位置ID,可以对应到ERNIE框架图中的Position Embeddings
# ...
pass
def ernie_send(src_feat, dst_feat, edge_feat):
"""doc"""
# input_ids,可以对应到ERNIE框架图中的Token Embeddings
cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
src_ids = L.concat([cls, src_feat["term_ids"]], 1)
dst_ids = dst_feat["term_ids"]
term_ids = L.concat([src_ids, dst_ids], 1) # sent_ids,可以对应到ERNIE框架图中的Segment Embeddings
sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1) # position_ids,可以对应到ERNIE框架图中的Position Embeddings
position_ids = build_position_ids(src_ids, dst_ids) term_ids.stop_gradient = True
sent_ids.stop_gradient = True
ernie = ErnieModel( # ERNIE模型
term_ids, sent_ids, position_ids,
config=self.config.ernie_config)
feature = ernie.get_pooled_output() # 得到发送过来的msg,该msg是由src节点和dst节点的文本特征一起过ERNIE后得到的embedding
return feature
def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name):
feature = L.unsqueeze(feature, [-1])
msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)]) # Send
neigh_feature = gw.recv(msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum")) # Recv,直接将发送来的msg根据dst节点来相加。 # 接下来的部分和ERNIESageV1类似,将self_feature和neigh_feature通过concat、normalize后得到需要的输出。
term_ids = feature
cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1], "int64", 1)
term_ids = L.concat([cls, term_ids], 1)
term_ids.stop_gradient = True
ernie = ErnieModel(
term_ids, L.zeros_like(term_ids),
config=self.config.ernie_config)
self_feature = ernie.get_pooled_output()
self_feature = L.fc(self_feature,
hidden_size,
act=act,
param_attr=F.ParamAttr(name=name + "_l.w_0",
learning_rate=learning_rate),
bias_attr=name+"_l.b_0"
)
neigh_feature = L.fc(neigh_feature,
hidden_size,
act=act,
param_attr=F.ParamAttr(name=name + "_r.w_0",
learning_rate=learning_rate),
bias_attr=name+"_r.b_0"
)
output = L.concat([self_feature, neigh_feature], axis=1)
output = L.l2_normalize(output, axis=1)
return output
return erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name)

2.总结

通过以上两个版本的模型代码简单的讲解,我们可以知道他们的不同点,其实主要就是在消息传递机制的部分有所不同。ERNIESageV1版本只作用在text graph的节点上,在传递消息(Send阶段)时只考虑了邻居本身的文本信息;而ERNIESageV2版本则作用在了边上,在Send阶段同时考虑了当前节点和其邻居节点的文本信息,达到更好的交互效果。

图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用的更多相关文章

  1. Github项目推荐-图神经网络(GNN)相关资源大列表

    文章发布于公号[数智物语] (ID:decision_engine),关注公号不错过每一篇干货. 转自 | AI研习社 作者|Zonghan Wu 这是一个与图神经网络相关的资源集合.相关资源浏览下方 ...

  2. LUSE: 无监督数据预训练短文本编码模型

    LUSE: 无监督数据预训练短文本编码模型 1 前言 本博文本应写之前立的Flag:基于加密技术编译一个自己的Python解释器,经过半个多月尝试已经成功,但考虑到安全性问题就不公开了,有兴趣的朋友私 ...

  3. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  4. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  5. 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征

    1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...

  6. 1.keras实现-->自己训练卷积模型实现猫狗二分类(CNN)

    原数据集:包含 25000张猫狗图像,两个类别各有12500 新数据集:猫.狗 (照片大小不一样) 训练集:各1000个样本 验证集:各500个样本 测试集:各500个样本 1= 狗,0= 猫 # 将 ...

  7. 学习AI之NLP后对预训练语言模型——心得体会总结

    一.学习NLP背景介绍:      从2019年4月份开始跟着华为云ModelArts实战营同学们一起进行了6期关于图像深度学习的学习,初步了解了关于图像标注.图像分类.物体检测,图像都目标物体检测等 ...

  8. 【中文版 | 论文原文】BERT:语言理解的深度双向变换器预训练

    BERT:Pre-training of Deep Bidirectional Transformers for Language Understanding 谷歌AI语言组论文<BERT:语言 ...

  9. 预训练语言模型整理(ELMo/GPT/BERT...)

    目录 简介 预训练任务简介 自回归语言模型 自编码语言模型 预训练模型的简介与对比 ELMo 细节 ELMo的下游使用 GPT/GPT2 GPT 细节 微调 GPT2 优缺点 BERT BERT的预训 ...

  10. 预训练中Word2vec,ELMO,GPT与BERT对比

    预训练 先在某个任务(训练集A或者B)进行预先训练,即先在这个任务(训练集A或者B)学习网络参数,然后存起来以备后用.当我们在面临第三个任务时,网络可以采取相同的结构,在较浅的几层,网络参数可以直接加 ...

随机推荐

  1. 引擎之旅 Chapter.2 线程库

    预备知识可参考我整理的博客 Windows编程之线程:https://www.cnblogs.com/ZhuSenlin/p/16662075.html Windows编程之线程同步:https:// ...

  2. 插入排序C语言版本

    算法思路:        每趟将一个待排序的元素作为关键字,按照其关键字值的大小插入到已经排好的部分的适当位置上,直到插入完成.        数组中待排序的关键字前面的数据为已经排序的数据,关键字插 ...

  3. Openstack neutron:目录

    为什么? 最近一直在学习SDN方面的知识,本着"最好的学习就是分享"的精神,记录下本系列的文章,尝试更好地去理解SDN这一正当红的技术. 如何? SDN领域现在已经充斥了大量的公司 ...

  4. 第六章:Django 综合篇 - 7:网站地图sitemap

    网站地图是根据网站的结构.框架.内容,生成的导航网页,是一个网站所有链接的容器.很多网站的连接层次比较深,蜘蛛很难抓取到,网站地图可以方便搜索引擎或者网络蜘蛛抓取网站页面,了解网站的架构,为网络蜘蛛指 ...

  5. SCI论文写作指南

    目录 科技论文的特点 时态的使用 论文的逻辑结构 作者 选择期刊 写作 Title/论文题名 题名 题名的作用 题名基本要求 作者 作者姓名的拼音表达方式 作者单位名与地址的标署 摘要的写作与关键词 ...

  6. Vue子->父组件传值

    父组件引入: Import Test from'' 父页面使用: <Test ref="test" @m1="m2"><Test/> 子 ...

  7. 十分钟速成DevOps实践

    摘要:以华为云软件开发平台DevCloud为例,十分钟简单体验下DevOps应用上云实践--H5经典小游戏上云. 本文分享自华为云社区<<DevOps实践秘籍>十分钟速成DevOps ...

  8. 关于Linux下aws-cli-2版本的安装

    AWS CLI 版本 2 是 AWS CLI 的最新主版本,支持所有最新功能.版本 2 中引入的某些功能无法向后兼容版本 1,您必须升级才能访问这些功能. AWS CLI 版本 2 仅可作为捆绑安装程 ...

  9. 华为网关交换机开启DHCP server服务

    华为网关交换机可以配置基于全局地址池的DHCP服务器,也可以配置基于接口地址池的DHCP服务器,本人比较倾向于配置基于接口地址池的DHCP服务器,因此在这里只介绍后者. 第一步:开启DHCP功能 [S ...

  10. PHP全栈开发(三):CentOS 7 中 PHP 环境搭建及检测

    简单回顾一下我们在(一).(二)中所做的工作. 首先我们在(一)中设置了CentOS 7的网络. 其实这些工作在CentOS 6中都是很容易的,因为有鸟哥的Linux私房菜这样好的指导. 但是这些操作 ...