ERNIE代码解析
原创作者 |疯狂的Max
ERNIE代码解读
考虑到ERNIE使用BRET作为基础模型,为了让没有基础的NLPer也能够理解代码,笔者将先为大家简略的解读BERT模型的结构,完整代码可以参见[1]。
01 BERT的结构组成
BERT的代码最主要的是由分词模块、训练数据预处理、模型结构模块等几部分组成。
1.1 分词模块
模型在训练之前,需要对输入文本进行切分,并将切分的子词转换为对应的ID。这一功能主要由BertTokenizer来实现,主要在
/models/bert/tokenization_bert.py实现。
BertTokenizer 是基于BasicTokenizer和WordPieceTokenizer 的分词器:
BasicTokenizer负责按标点、空格等分割句子,并处理是否统一小写,以及清理非法字符。
WordPieceTokenizer在词的基础上,进一步将词分解为子词(subword)。
具有以下使用方法:
- from_pretrained:从包含词表文件(vocab.txt)的目录中初始化一个分词器;
- tokenize:将文本分解为子词列表;
- convert_tokens_to_ids:将子词转化为子词对应的下标;
- convert_ids_to_tokens :将对应下标转化为子词;
- encode:对于单个句子,分解词并加入特殊词形成“[CLS], x, [SEP]”的结构并转换为词表对应下标的列表;
- decode:将encode的输出转换为句子。
1.2 训练数据预处理
训练数据的构建主要取决于预训练的任务,由于BERT的预训练任务包括预测上下句和掩码词预测是否为连续句,那么其训练数据就需要随机替换连续的语句和其中的分词,这部分功能由run_pretraining.py中的函数
create_instances_from_document实现。
该部分首先构建上下句,拼接[cls]和[sep]等特殊符号的id,构建长度为512的列表,然后根据论文中所使用的指定概率选择要掩码的子词,这部分由函数
create_masked_lm_predictions实现。
1.3 模型结构
BERT模型主要由BertEmbeddings类、BertEncoder类组成,前者负责将子词、位置和上下句标识(segment)投影成向量,后者实现文本的编码。
编码器BertEncoder又由12层相同的编码块BertLayer组成。每一层都由自注意力层BertSelfAttention和前馈神经网络层BertIntermediate以及输出层BertOutput构成,在
/models/bert/modeling_bert.py中实现。
每一层编码层的结构和功能如下:
- BertSelfAttention:负责实现子词之间的相互关注。注意,多头自注意力机制的实现是通过将维度为hidden_size 的表示向量切分成n个维度为hidden_size / n的向量,再对切分的向量分别进行编码,最后拼接编码后的向量实现的;
- BertIntermediate:将批次数据(三维张量)做矩阵相乘和非线性变化;
- BertOutput :实现归一化和残差连接;
工程小技巧: 如果模型在学习表示向量的过程中需要使用不同的编码方式,以结合图神经网络层和Transformer编码层为例,笔者建议尽量使用相同的参数初始化方式,两者都使用残差连接,这能够避免模型训练时出现梯度爆炸的问题。
此外是否需要对注意力权重进行大小的变化,如Transformer会除以向量维度的开方,则取决于图神经网络的层数,一般而言,仅使用两层或以下的图神经网络层,则无需对注意力权重做变化。
具体可以通过观察图神经网络层生成的表示向量的大小是否和Transformer编码层生成的向量大小在同一个数量级来决定,如果在同一个数量级则无需改变注意力权重,如果出现梯度爆炸的现象,那么则可以缩小注意力的权重。
02 从BERT到ERNIE
由于ERNIE是在BERT的基础上进行改进,在数据层面需要构建与文本对应的实体序列,在预训练层面加入了新的预训练任务,那么在代码上就对应着训练数据预处理和模型结构这两方面的改动。因此笔者也将重点针对这两个方面进行讲解,完整代码参见[2]。
其代码结构主要包含两大模块,训练数据预处理模块和模型构建模块。
2.1 训练数据预处理模块
ERNIE模型的知识注入依赖于找到文本中存在的实体,这些实体是指具有意义的抽象或者具象的单个名词或名词短语,我们可以将其称为文本指称项(mention)。一个实体可以有多个别名,也就意味着一个实体可以对应着文本中的多个指称项。
为了能够找到文本语料中实体,作者使用维基百科作为ERNIE的训练语料,将维基百科中具有超链接的名词或者短语作为实体,利用这一现有资源能够大大的简化检索实体的难度。
2.1.1 训练数据构建
在利用现有抽取工具获得语料和实体名文件后,通过
pretrain_data/create_insts.py构建训练数据。
我们知道在训练之前,首先需要对语料进行分词(tokenize),获得子词(tokens),然后根据词典得到子词的索引ID,模型在接收索引后将其投影成向量。从BERT的代码中我们可以知道,BERT首先构建用于下一句预测(next sentences prediction)所需要的上下句,并从中随机选择掩码词,生成用于自注意力阶段的掩码列表。
那么为了能够注入语句中对应的实体,ERNIE就需要在这一过程中创建和训练语料等长的实体ID张量,以及对应的掩码列表。
作者仅仅对文本指称项第一个子词所对应的位置标注实体ID,这也就意味模型仅使用第一个子词向量预测实体。这种做法能够直接复用BERT的代码,而无需单独针对实体序列再构建训练数据,减轻了工程实现的工作量。
for i, x in enumerate(vec):
if x == "#UNK#":
vec[i] = -1
elif x[0] == "Q":
if x in d:
vec[i] = d[x]
if i != 0 and vec[i] == vec[i-1]:
# 以某个实体为例,Q123 Q123 Q123 -> d[Q123] -1 -1,仅在第一个子词中记录实体的ID,其他位置标志为-1
vec[i] = -1
else:
vec[i] = -1
#函数 create_instances_from_document
// 获取句子a和b的实体和子词
tokens = [101] + tokens_a + [102] + tokens_b + [102]
entity = [-1] + entity_a + [-1] + entity_b + [-1]
// 构造用于为数据构建索引的对象ds,并将对应的输入语料id列表及掩码列表,实体id列表和掩码列表等训练数据存入ds。
ds.add_item(torch.IntTensor(input_ids+input_mask+segment_ids
+masked_lm_labels+entity+entity_mask+[next_sentence_label]))
2.1.2 实体向量加载
BERT由于具有经过预训练的向量表,子词的ID值可以利用nn.embedding模块获取投影向量。
那么实体的向量是经过TransE表示学习获得的,又应该如何让模型获取其投影向量呢?作者在code/iteration.py中自定义数据迭代器对象,该对象在返回数据时会调用
torch.utils.data.DataLoader,通过在该函数中传入负责投影实体向量的函数collate_fn,能够让模型在加载数据时获取实体的表示向量。
#类 EpochBatchIterator(object):
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
# collate_fn是传入实体向量的关键
collate_fn=self.collate_fn,
batch_sampler=batches,
))
#函数collate_fn:
def collate_fn(x):
x = torch.LongTensor([xx for xx in x])
entity_idx = x[:, 4*args.max_seq_length:5*args.max_seq_length]
# embed = torch.nn.Embedding.from_pretrained(embed)
# embed为加载了经过预训练的二维实体张量
uniq_idx = np.unique(entity_idx.numpy())
ent_candidate = embed(torch.LongTensor(uniq_idx+1))
2.2 模型结构模块
在模型方面,作者依旧使用12层Transformer编码层作为模型结构,与BERT所不同的是,在前6层沿用BERT的Transformer编码层,但在第7层自定义知识融合层BertLayerMix,首次对经过对齐的实体向量和指称项向量求和,并将其分别传输给知识编码模块和文本编码模块,在剩下5层自定义知识编码层BertLayer,分别对经过融合了两者信息的实体序列和文本序列使用自注意力机制编码。
模型的前5层就是论文所指的文本编码器,后面的7层编码层则构成了论文中的知识编码器。
对于BERT的Transformer编码层,由于第一部分已经介绍过,就不再赘述。下文主要针对作者自定义的编码层做详细解读。
2.2.1 知识融合层BertLayerMix
具体来说,知识融合层BertLayerMix由自注意力层BertAttention_simple、融合层BertIntermediate以及输出层BertOutput构成。
class BertLayerMix(nn.Module):
def __init__(self, config):
super(BertLayerMix, self).__init__()
self.attention = BertAttention_simple(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
# 该编码层仅针对文本进行自注意力操作、矩阵相乘和残差连接
def forward(self, hidden_states, attention_mask, hidden_states_ent, attention_mask_ent, ent_mask):
attention_output = self.attention(hidden_states, attention_mask)
attention_output_ent = hidden_states_ent * ent_mask
# intermediate层负责实体和文本向量求和,并对求和向量非线性变化
intermediate_output = self.intermediate(attention_output, attention_output_ent)
# 然后通过输出层output再次归一化和残差连接
layer_output, layer_output_ent = self.output(intermediate_output, attention_output, attention_output_ent)
return layer_output, layer_output_ent
自注意力层BertAttention_simple由BertSelfAttention和BertSelfOutput构成,前者负责对文本进行自注意力操作,实现上与BERT的自注意力操作相同,就不再展示代码。后者则用于对向量进行矩阵变化和残差连接,生成attention_output
class BertAttention_simple(nn.Module):
def __init__(self, config):
super(BertAttention_simple, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
前馈神经网络层BertIntermediate负责将两者进行线性变化转换为同一维度,求和并做非线性变化。
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dense_ent = nn.Linear(100, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states, hidden_states_ent):
# 线性变化转换为同一维度
hidden_states_ = self.dense(hidden_states)
hidden_states_ent_ = self.dense_ent(hidden_states_ent)
# 求和并使用intermediate_act_fn做非线性变化
hidden_states = self.intermediate_act_fn(hidden_states_+hidden_states_ent_)
return hidden_states
最终使用BertOutput分别对文本向量和实体向量做矩阵相乘,将经过融合的向量和两者残差连接,并做归一化操作。
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dense_ent = nn.Linear(config.intermediate_size, 100)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm_ent = BertLayerNorm(100, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states_, input_tensor, input_tensor_ent):
# 针对文本向量矩阵相乘
hidden_states = self.dense(hidden_states_)
hidden_states = self.dropout(hidden_states)
# 针对文本向量残差连接和归一化
hidden_states = self.LayerNorm(hidden_states + input_tensor)
# 针对实体向量的矩阵相乘、残差连接和归一化
hidden_states_ent = self.dense_ent(hidden_states_)
hidden_states_ent = self.dropout(hidden_states_ent)
hidden_states_ent = self.LayerNorm_ent(hidden_states_ent + input_tensor_ent)
return hidden_states, hidden_states_ent
2.2.2 知识编码层BertLayer
该编码层针对融合后的实体向量和文本向量分别进行自注意力编码,从而使实体序列中的所有实体也能够实现相互关注。
再次基础上实体向量将和对应位置的文本向量求和,将实体信息传递给文本向量,从而使整个文本序列在下一个编码层中实现对实体序列的关注。
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask, hidden_states_ent, attention_mask_ent, ent_mask):
attention_output, attention_output_ent = self.attention(hidden_states, attention_mask, hidden_states_ent, attention_mask_ent)
attention_output_ent = attention_output_ent * ent_mask
intermediate_output = self.intermediate(attention_output, attention_output_ent)
layer_output, layer_output_ent = self.output(intermediate_output, attention_output, attention_output_ent)
# layer_output_ent = layer_output_ent * ent_mask
return layer_output, layer_output_ent
这一编码层自定义了自注意力层,其中针对实体的自注意力层仅使用4个注意力头。
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
config_ent = copy.deepcopy(config)
config_ent.hidden_size = 100
config_ent.num_attention_heads = 4
self.self_ent = BertSelfAttention(config_ent)
self.output_ent = BertSelfOutput(config_ent)
def forward(self, input_tensor, attention_mask, input_tensor_ent, attention_mask_ent):
# BertSelfAttention对文本向量进行自注意力操作
self_output = self.self(input_tensor, attention_mask)
self_output_ent = self.self_ent(input_tensor_ent, attention_mask_ent)
# BertSelfAttention对实体向量进行自注意力操作
attention_output = self.output(self_output, input_tensor)
attention_output_ent = self.output_ent(self_output_ent, input_tensor_ent)
return attention_output, attention_output_ent
输出层同知识融合层一样,都是使用BERToutput实现归一化和残差连接。
03 源代码参考
[1] https://github.com/google-research/bert
[2] https://github.com/thunlp/ERNIE
私信我领取目标检测与R-CNN/数据分析的应用/电商数据分析/数据分析在医疗领域的应用/NLP学员项目展示/中文NLP的介绍与实际应用/NLP系列直播课/NLP前沿模型训练营等干货学习资源。
ERNIE代码解析的更多相关文章
- VBA常用代码解析
031 删除工作表中的空行 如果需要删除工作表中所有的空行,可以使用下面的代码. Sub DelBlankRow() DimrRow As Long DimLRow As Long Dimi As L ...
- [nRF51822] 12、基础实验代码解析大全 · 实验19 - PWM
一.PWM概述: PWM(Pulse Width Modulation):脉冲宽度调制技术,通过对一系列脉冲的宽度进行调制,来等效地获得所需要波形. PWM 的几个基本概念: 1) 占空比:占空比是指 ...
- [nRF51822] 11、基础实验代码解析大全 · 实验16 - 内部FLASH读写
一.实验内容: 通过串口发送单个字符到NRF51822,NRF51822 接收到字符后将其写入到FLASH 的最后一页,之后将其读出并通过串口打印出数据. 二.nRF51822芯片内部flash知识 ...
- [nRF51822] 10、基础实验代码解析大全 · 实验15 - RTC
一.实验内容: 配置NRF51822 的RTC0 的TICK 频率为8Hz,COMPARE0 匹配事件触发周期为3 秒,并使能了TICK 和COMPARE0 中断. TICK 中断中驱动指示灯D1 翻 ...
- [nRF51822] 9、基础实验代码解析大全 · 实验12 - ADC
一.本实验ADC 配置 分辨率:10 位. 输入通道:5,即使用输入通道AIN5 检测电位器的电压. ADC 基准电压:1.2V. 二.NRF51822 ADC 管脚分布 NRF51822 的ADC ...
- java集合框架之java HashMap代码解析
java集合框架之java HashMap代码解析 文章Java集合框架综述后,具体集合类的代码,首先以既熟悉又陌生的HashMap开始. 源自http://www.codeceo.com/arti ...
- Kakfa揭秘 Day8 DirectKafkaStream代码解析
Kakfa揭秘 Day8 DirectKafkaStream代码解析 今天让我们进入SparkStreaming,看一下其中重要的Kafka模块DirectStream的具体实现. 构造Stream ...
- linux内存管理--slab及其代码解析
Linux内核使用了源自于 Solaris 的一种方法,但是这种方法在嵌入式系统中已经使用了很长时间了,它是将内存作为对象按照大小进行分配,被称为slab高速缓存. 内存管理的目标是提供一种方法,为实 ...
- MYSQL常见出错mysql_errno()代码解析
如题,今天遇到怎么一个问题, 在理论上代码是不会有问题的,但是还是报了如上的错误,把sql打印出來放到DB中却可以正常执行.真是郁闷,在百度里面 渡 了很久没有相关的解释,到时找到几个没有人回复的 & ...
随机推荐
- 【LeetCode】144. Binary Tree Preorder Traversal 解题报告(Python&C++&Java)
作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 递归 迭代 日期 题目地址:https://leetc ...
- 【LeetCode】763. Partition Labels 解题报告(Python & C++)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 解题方法 日期 题目地址:https://leetcode.com/pr ...
- 【LeetCode】201. Bitwise AND of Numbers Range 解题报告(Python)
[LeetCode]201. Bitwise AND of Numbers Range 解题报告(Python) 标签: LeetCode 题目地址:https://leetcode.com/prob ...
- 来自Java程序员的Python新手入门小结
欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...
- 【汇编语言】李忠《x86汇编语言——从实模式到保护模式》
该书配套资料网址已经失效 配套资料和章节答案下载 查看最新作者网址:http://www.lizhongc.com/ 勘误表:https://wenku.baidu.com/view/9213288b ...
- Gradient-based Hyperparameter Optimization through Reversible Learning
目录 概 主要内容 算法 finite precision arithmic 实验 Maclaurin D, Duvenaud D, Adams R P, et al. Gradient-based ...
- WebRTC源码开发(一)MacOS下源码下载、编译及Demo运行
工作需要测试网络传输算法,逐学习WebRTC源码 工作环境 Mac OS 10.14 Xcode 10.2.1 源码下载 从google(需要[你懂的]) 首先[你懂的] 打开终端,输入curl ww ...
- 在页面中添加两个 <select> 标签,用来显示年份和月份;同时添加两个 <ul> 标签,一个用来显示星期,另一个用来显示日期 在 JavaScript 脚本中动态添加年份和月份,获取当前日期的年份
查看本章节 查看作业目录 需求说明: 使用 JavaScript 中的 Date 对象,在页面上显示一个万年历.选择不同的年份和月份,在页面中显示当前月的日历 实现思路: 在页面中添加两个 <s ...
- TortoiseGit使用技巧
汇总TortoiseGit使用技巧,包括提交代码,创建patch等等. 1.提交代码到本地仓库 在Git工程目录下右键, 点击 Git Commit -> "master". ...
- ActiveMQ基础教程(四):.net core集成使用ActiveMQ消息队列
接上一篇:ActiveMQ基础教程(三):C#连接使用ActiveMQ消息队列 这里继续说下.net core集成使用ActiveMQ.因为代码比较多,所以放到gitee上:https://gitee ...