【算法】Bert预训练源码阅读
Bert预训练源码
主要代码
地址:https://github.com/google-research/bert
- create_pretraning_data.py:原始文件转换为训练数据格式
- tokenization.py:汉字,单词切分,复合词处理,create_pretraning_data中调用
- modeling.py: 模型结构
- run_pretraing.py: 运行预训练
tokenization.py
作用:句子切分,特殊符号处理。
主要类:BasicTokenizer, WordpieceTokenizer, FullTokenizer
- BasicTokenizer.tokenize: 文本转为unicode, 去除特殊符号,汉字前后加空格,按空格切分单词,去掉文本重音,按标点符号切割单词。最后生成一个list
- WordpieceTokenizer.tokenize: 长度过长的单词标记为UNK,复合词切分,找不到的词标记为UNK
- FullTokenizer:先后调用BasicTokenizer和WordpieceTokenizer
create_pretraning_data.py
输入:词典, 原始文本(空行分割不同文章,一行一句)
输出:训练数据
作用:生成训练数据,句子对组合,单词mask等
入口函数main
- 加载词典,加载原始文本
- create_training_instances
读取原始文本文件,做unicode转换,中文,标点,特殊符号处理,空格切分,复合词切分。转换为[[[first doc first sentence],[first doc second sentence],[first doc third sentence]],[[second doc first sentence],[]],....] 这样的结构
去除空文章,文章顺序打乱
输入的原始文本会重复使用dupe_factor次 对每一篇文章生成训练数据create_instances_from_document
训练语句长度限制max_seq_length,0.1的概率生成长度较小的训练语句,增加鲁棒性
句子对(A,B)随机组合
对于一篇文章,按顺序获取n行句子,其长度总和限制为target_seq_length,
随机选取n行中的前m行作为A
0.5的概率,B是n行中后面剩余的部分;其他情况,B是随机选取的其他文章内容,开始位置是随机的
文章中没有使用的部分继续组合(A, B)
添加CLS,SEP分隔符,生成句子向量
对句子对中的单词做随机mask (create_masked_lm_predictions), 随机取num_to_predict个单词做mask,0.8的概率标记为MASK,0.1的概率标记为原始单词,0.1的概率标记为随机单词
封装,句子对,句子id,是否为随机下一句,mask的下标位置,mask对应的原始单词训练数据序列化,存入文件。单词转为id,句子长度不足的后面补0。
modeling.py
BertConfig: 配置
BertModel: 模型主体
建模主体过程:
- 获取词向量 [batch_size, seq_length, embedding_size]
- 添加句向量,添加位置向量,在最后一个维度上做归一化,整体做dropout
- transformer
全连接映射 [B*F, embedding_size]->[B*F, N*H]
\(dropout(softmax(QK^T))V\), 其中mask了原本没有数据的部分
全连接,dropout,残差处理,归一化,全连接,dropout,残差处理,归一化
上述循环多层
取最终[CLS]对应的向量做句向量
run_pretraining.py
作用:生成目标函数,加载已有参数,迭代训练
主要函数:model_fn_builder
- 评估mask单词的预测准确性,整体loss为mask处预测对的分数的平均值
- 评估next_sentence预测准确性,loss为预测对的概率值
- 总损失为上面两个损失相加
【算法】Bert预训练源码阅读的更多相关文章
- 谷歌BERT预训练源码解析(一):训练数据生成
目录预训练源码结构简介输入输出源码解析参数主函数创建训练实例下一句预测&实例生成随机遮蔽输出结果一览预训练源码结构简介关于BERT,简单来说,它是一个基于Transformer架构,结合遮蔽词 ...
- 谷歌BERT预训练源码解析(三):训练过程
目录前言源码解析主函数自定义模型遮蔽词预测下一句预测规范化数据集前言本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BERT针对两个任务同 ...
- 谷歌BERT预训练源码解析(二):模型构建
目录前言源码解析模型配置参数BertModelword embeddingembedding_postprocessorTransformerself_attention模型应用前言BERT的模型主要 ...
- Bert源码阅读
前言 对Google开源出来的bert代码,来阅读下.不纠结于代码组织形式,而只是梳理下其训练集的生成,训练的self-attention和multi-head的具体实现. 训练集的生成 主要实现在c ...
- caffe-windows中classification.cpp的源码阅读
caffe-windows中classification.cpp的源码阅读 命令格式: usage: classification string(模型描述文件net.prototxt) string( ...
- 【原】SDWebImage源码阅读(四)
[原]SDWebImage源码阅读(四) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 SDWebImage中主要实现了NSURLConnectionDataDelega ...
- 如何阅读Java源码 阅读java的真实体会
刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比 ...
- 36 网络相关函数(四)——live555源码阅读(四)网络
36 网络相关函数(四)——live555源码阅读(四)网络 36 网络相关函数(四)——live555源码阅读(四)网络 简介 7)createSocket创建socket方法 8)closeSoc ...
- 15 BasicHashTable基本哈希表类(二)——Live555源码阅读(一)基本组件类
这是Live555源码阅读的第一部分,包括了时间类,延时队列类,处理程序描述类,哈希表类这四个大类. 本文由乌合之众 lym瞎编,欢迎转载 http://www.cnblogs.com/oloroso ...
随机推荐
- SpringBoot配置日志logback
1.这里我们选择logback,首先加入pom依赖 <dependency> <groupId>ch.qos.logback</groupId> <artif ...
- 高斯消元与行列式求值 part1
两道模板题,思路与算法却是相当经典. 先说最开始做的行列式求值,题目大致为给一个10*10的行列式,求其值 具体思路(一开始看到题我的思路): 1.暴算,把每种可能组合试一遍,求逆序数,做相应加减运算 ...
- SVG矢量图学习实例
从W3school上学习了一下SVG矢量图形,感觉和HTML相比还是有一些新的元素和属性的,一时间不能全部记住,特此留下笔记,供遗忘时作为参考 <!DOCTYPE html> <!- ...
- Day042---浮动 背景图设置 相对定位绝对定位
1.练习浮动 2.文本属性和字体属性 文本对齐 text-align left 左对齐 right 右对齐 center 中心对齐 justify 两边对齐 只适应于英文 text-indent ...
- idea使用记录
1.在工具栏添加工具
- JGUI源码:响应式布局简单实现(13)
首先自我检讨下,一直没有认真研究过响应式布局,有个大致概念响应式就是屏幕缩小了就自动换行或者隐藏显示,就先按自己的理解来闭门造车思考实现过程吧. 1.首先把显示区域分成12等分,bootstrap是这 ...
- PMP知识点(二)——三点估算的两种方法对活动持续时间估算的影响和如何取舍
一.准备工作 活动持续时间的估算属于PMBOK中第六章项目时间管理中第五节6.6估算活动持续时间的内容. 三点估算是6.5和7.2(估算成本)中应用到的一种工具和技术.数据流向图参考如下: 其应用到的 ...
- VS注释快捷键
注释: 先CTRL+K,然后CTRL+C 取消注释: 先CTRL+K,然后CTRL+U 代码自动对齐:1, ctrl+a 2, ctrl+k 3, ctrl+f
- ultraEdit软件比较两个文件内容的不同处
1.软件名称为:UltraEdit ,安装并打开软件; 软件图标: 打开软件如图所示: 2.点击导航图标,蓝色上面有Uc图标,该图标名称为“比较文件” 如图位置: 3.弹出框,根据文件路径选择好比较的 ...
- SpringMVC+Apache Shiro+JPA(hibernate)案例教学(四)基于Shiro验证用户权限,且给用户授权
最新项目比较忙,写文章的精力就相对减少了,但看到邮箱里的几个催更,还是厚颜把剩下的文档补上. 一.修改ShiroDbRealm类,实现它的doGetAuthorizationInfo方法 packag ...