谷歌BERT预训练源码解析(三):训练过程
目录
前言
源码解析
主函数
自定义模型
遮蔽词预测
下一句预测
规范化数据集
前言
本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨。BERT针对两个任务同时训练。1.下一句预测。2.遮蔽词识别
下面介绍BERT的预训练模型run_pretraining.py是怎么训练的。
源码解析
主函数
训练过程主要用了estimator调度器。这个调度器支持自定义训练过程,将训练集传入之后自动训练。详情见注释
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
if not FLAGS.do_train and not FLAGS.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
tf.gfile.MakeDirs(FLAGS.output_dir)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Input Files ***")
for input_file in input_files:
tf.logging.info(" %s" % input_file)
tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig( #训练参数
cluster=tpu_cluster_resolver,
master=FLAGS.master,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
model_fn = model_fn_builder( #自定义模型,用于estimator训练
bert_config=bert_config,
init_checkpoint=FLAGS.init_checkpoint,
learning_rate=FLAGS.learning_rate,
num_train_steps=FLAGS.num_train_steps,
num_warmup_steps=FLAGS.num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu)
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator( #创建TPUEstimator
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
if FLAGS.do_train: #训练过程
tf.logging.info("***** Running training *****")
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
train_input_fn = input_fn_builder( #创建输入训练集
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=True)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
if FLAGS.do_eval: #验证过程
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
eval_input_fn = input_fn_builder( #创建验证集
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=False)
result = estimator.evaluate(
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
with tf.gfile.GFile(output_eval_file, "w") as writer:
tf.logging.info("***** Eval results *****")
for key in sorted(result.keys()):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
自定义模型
首先获取数据内容,传入到上一篇定义的模型中。对下一句预测任务取出模型的[CLS]结果。对遮蔽词预测任务取出模型的最后结果。然后分别计算loss值,最后将loss值相加。详情见注释
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
"""Returns `model_fn` closure for TPUEstimator."""
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
"""The `model_fn` for TPUEstimator."""
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
#获取数据内容
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
masked_lm_positions = features["masked_lm_positions"]
masked_lm_ids = features["masked_lm_ids"]
masked_lm_weights = features["masked_lm_weights"]
next_sentence_labels = features["next_sentence_labels"]
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
传入到Bert模型中。
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
#遮蔽预测的batch_loss,平均loss,预测概率矩阵
(masked_lm_loss,
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
bert_config, model.get_sequence_output(), model.get_embedding_table(),
masked_lm_positions, masked_lm_ids, masked_lm_weights)
#下一句预测的batch_loss,平均loss,预测概率矩阵
(next_sentence_loss, next_sentence_example_loss,
next_sentence_log_probs) = get_next_sentence_output(
bert_config, model.get_pooled_output(), next_sentence_labels)
#loss相加
total_loss = masked_lm_loss + next_sentence_loss
#获取所有变量
tvars = tf.trainable_variables()
initialized_variable_names = {}
scaffold_fn = None
#如果有之前保存的模型
if init_checkpoint:
(assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
if use_tpu:
def tpu_scaffold():
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
return tf.train.Scaffold()
scaffold_fn = tpu_scaffold
else:
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
#如果有之前保存的模型
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
output_spec = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = optimization.create_optimizer( #自定义好的优化器
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
output_spec = tf.contrib.tpu.TPUEstimatorSpec( #Estimator要求返回一个EstimatorSpec对象
mode=mode,
loss=total_loss,
train_op=train_op,
scaffold_fn=scaffold_fn)
#验证过程
elif mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
masked_lm_weights, next_sentence_example_loss,
next_sentence_log_probs, next_sentence_labels):
"""Computes the loss and accuracy of the model."""
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
[-1, masked_lm_log_probs.shape[-1]]) #概率矩阵转成[batch_size*max_pred_pre_seq,vocab_size]
masked_lm_predictions = tf.argmax(
masked_lm_log_probs, axis=-1, output_type=tf.int32) #取最大值位置为输出
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) #每句loss列表 [batch_size*max_pred_per_seq]
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
masked_lm_accuracy = tf.metrics.accuracy( #计算准确率
labels=masked_lm_ids,
predictions=masked_lm_predictions,
weights=masked_lm_weights)
masked_lm_mean_loss = tf.metrics.mean( #计算平均loss
values=masked_lm_example_loss, weights=masked_lm_weights)
next_sentence_log_probs = tf.reshape(
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
next_sentence_predictions = tf.argmax( #获取最大位置为输出
next_sentence_log_probs, axis=-1, output_type=tf.int32)
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
next_sentence_accuracy = tf.metrics.accuracy( #计算准确率
labels=next_sentence_labels, predictions=next_sentence_predictions)
next_sentence_mean_loss = tf.metrics.mean( 计算平均loss
values=next_sentence_example_loss)
return {
"masked_lm_accuracy": masked_lm_accuracy,
"masked_lm_loss": masked_lm_mean_loss,
"next_sentence_accuracy": next_sentence_accuracy,
"next_sentence_loss": next_sentence_mean_loss,
}
eval_metrics = (metric_fn, [
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
masked_lm_weights, next_sentence_example_loss,
next_sentence_log_probs, next_sentence_labels
])
output_spec = tf.contrib.tpu.TPUEstimatorSpec( #Estimator要求返回一个EstimatorSpec对象
mode=mode,
loss=total_loss,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
else:
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
return output_spec
return model_fn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
遮蔽词预测
输入BERT模型的最后一层encoder,输出遮蔽词预测任务的loss和概率矩阵。详情见注释
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
label_ids, label_weights):
#这里的input_tensor是模型中传回的最后一层结果 [batch_size,seq_length,hidden_size]。
#output_weights是词向量表 [vocab_size,embedding_size]
"""Get loss and log probs for the masked LM."""
#获取positions位置的所有encoder(即要预测的那些位置的encoder)
input_tensor = gather_indexes(input_tensor, positions) #[batch_size*max_pred_pre_seq,hidden_size]
with tf.variable_scope("cls/predictions"):
# We apply one more non-linear transformation before the output layer.
# This matrix is not used after pre-training.
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense( #传入一个全连接层 输出shape [batch_size*max_pred_pre_seq,hidden_size]
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size]
logits = tf.nn.bias_add(logits, output_bias) #加bias
log_probs = tf.nn.log_softmax(logits, axis=-1) #[batch_size*max_pred_pre_seq,vocab_size]
label_ids = tf.reshape(label_ids, [-1]) #[batch_size*max_pred_per_seq]
label_weights = tf.reshape(label_weights, [-1])
one_hot_labels = tf.one_hot( #[batch_size*max_pred_per_seq,vocab_size]
label_ids, depth=bert_config.vocab_size, dtype=tf.float32) #label id转one hot
# The `positions` tensor might be zero-padded (if the sequence is too
# short to have the maximum number of predictions). The `label_weights`
# tensor has a value of 1.0 for every real prediction and 0.0 for the
# padding predictions.
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) #[batch_size*max_pred_per_seq]
numerator = tf.reduce_sum(label_weights * per_example_loss) #[1] 一个batch的loss
denominator = tf.reduce_sum(label_weights) + 1e-5
loss = numerator / denominator #平均loss
return (loss, per_example_loss, log_probs)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
下一句预测
输入BERT模型CLS的encoder,输出下一句预测任务的loss和概率矩阵,详情见注释
def get_next_sentence_output(bert_config, input_tensor, labels):
#input_tensor shape [batch_size,hidden_size]
"""Get loss and log probs for the next sentence prediction."""
# Simple binary classification. Note that 0 is "next sentence" and 1 is
# "random sentence". This weight matrix is not used after pre-training.
with tf.variable_scope("cls/seq_relationship"):
output_weights = tf.get_variable(
"output_weights",
shape=[2, bert_config.hidden_size],
initializer=modeling.create_initializer(bert_config.initializer_range))
output_bias = tf.get_variable(
"output_bias", shape=[2], initializer=tf.zeros_initializer()) #[batch_size,hidden_size]
logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size,2]
logits = tf.nn.bias_add(logits, output_bias) #[batch_size,2]
log_probs = tf.nn.log_softmax(logits, axis=-1)
labels = tf.reshape(labels, [-1])
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) #[batch_size,2]
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) #[batch_size]
loss = tf.reduce_mean(per_example_loss) #[1]
return (loss, per_example_loss, log_probs)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
规范化数据集
Estimator要求模型的输入为特定格式(from_tensor_slices),所以要对数据进行类封装
def input_fn_builder(input_files,
max_seq_length,
max_predictions_per_seq,
is_training,
num_cpu_threads=4):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
name_to_features = {
"input_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"input_mask":
tf.FixedLenFeature([max_seq_length], tf.int64),
"segment_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"masked_lm_positions":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_ids":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_weights":
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
"next_sentence_labels":
tf.FixedLenFeature([1], tf.int64),
}
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
if is_training:
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
d = d.repeat() #重复
d = d.shuffle(buffer_size=len(input_files)) #打乱
# `cycle_length` is the number of parallel files that get read.
cycle_length = min(num_cpu_threads, len(input_files))
# `sloppy` mode means that the interleaving is not exact. This adds
# even more randomness to the training pipeline.
d = d.apply(
tf.contrib.data.parallel_interleave( #生成嵌套数据集,并且输出其元素隔行交错
tf.data.TFRecordDataset,
sloppy=is_training,
cycle_length=cycle_length))
d = d.shuffle(buffer_size=100)
else:
d = tf.data.TFRecordDataset(input_files)
# Since we evaluate for a fixed number of steps we don't want to encounter
# out-of-range exceptions.
d = d.repeat()
# We must `drop_remainder` on training because the TPU requires fixed
# size dimensions. For eval, we assume we are evaluating on the CPU or GPU
# and we *don't* want to drop the remainder, otherwise we wont cover
# every sample.
d = d.apply(
tf.contrib.data.map_and_batch( #结构转换
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
num_parallel_batches=num_cpu_threads,
drop_remainder=True))
return d
return input_fn
---------------------
作者:保持一份率性
来源:CSDN
原文:https://blog.csdn.net/weixin_39470744/article/details/84619903
版权声明:本文为博主原创文章,转载请附上博文链接!

谷歌BERT预训练源码解析(三):训练过程的更多相关文章
- [源码解析] 分布式训练Megatron (1) --- 论文 & 基础
[源码解析] 分布式训练Megatron (1) --- 论文 & 基础 目录 [源码解析] 分布式训练Megatron (1) --- 论文 & 基础 0x00 摘要 0x01 In ...
- Celery 源码解析三: Task 对象的实现
Task 的实现在 Celery 中你会发现有两处,一处位于 celery/app/task.py,这是第一个:第二个位于 celery/task/base.py 中,这是第二个.他们之间是有关系的, ...
- Mybatis源码解析(三) —— Mapper代理类的生成
Mybatis源码解析(三) -- Mapper代理类的生成 在本系列第一篇文章已经讲述过在Mybatis-Spring项目中,是通过 MapperFactoryBean 的 getObject( ...
- 谷歌BERT预训练源码解析(一):训练数据生成
目录预训练源码结构简介输入输出源码解析参数主函数创建训练实例下一句预测&实例生成随机遮蔽输出结果一览预训练源码结构简介关于BERT,简单来说,它是一个基于Transformer架构,结合遮蔽词 ...
- 谷歌BERT预训练源码解析(二):模型构建
目录前言源码解析模型配置参数BertModelword embeddingembedding_postprocessorTransformerself_attention模型应用前言BERT的模型主要 ...
- ReactiveCocoa源码解析(三) Signal代码的基本实现
上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...
- ReactiveSwift源码解析(三) Signal代码的基本实现
上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...
- React的React.createRef()/forwardRef()源码解析(三)
1.refs三种使用用法 1.字符串 1.1 dom节点上使用 获取真实的dom节点 //使用步骤: 1. <input ref="stringRef" /> 2. t ...
- 第三十四节,目标检测之谷歌Object Detection API源码解析
我们在第三十二节,使用谷歌Object Detection API进行目标检测.训练新的模型(使用VOC 2012数据集)那一节我们介绍了如何使用谷歌Object Detection API进行目标检 ...
随机推荐
- 使用Jest进行单元测试
Jest是Facebook推出的一款单元测试工具. 安装 npm install --save-dev jest ts-jest @types/jest 在package.json中添加脚本: “te ...
- neo4j批量导入neo4j-import
neo4j数据批量导入 1 neo4j基本参数 1.1 启动与关闭: 1.2 neo4j-admin的参数:控制内存 1.2.1 memrec 是查看参考内存设置 1.2.2 指定缓存–pagecac ...
- 2019.8.5 NOIP模拟测试13 反思总结【已更新完毕】
还没改完题,先留个坑. 放一下AC了的代码,其他东西之后说… 改完了 快下课了先扔代码 跑了跑了 思路慢慢写 来补完了[x 刚刚才发现自己打错了标题 这次考试挺爆炸的XD除了T3老老实实打暴力拿了52 ...
- CenOS SSH无密码登录
系统环境:CentOS6.8 软件环境:SSH(yum -y install openssh-clients) IP 地址:192.168.0.188 用户环境:root.xiaoming 系统 ...
- pip list报错:DEPRECATION: The default format will switch to columns in the future.
一.现象: pip list 显示出以下错误: DEPRECATION: The default format will switch to columns in the future. Yo ...
- arcgis增大缩放级别
<!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...
- 利用CSS使footer固定在页面底部
1.HTML基本结构 <!DOCTYPEhtml> <htmlxmlns="http://www.w3.org/1999/xhtml"> <headr ...
- 本地 vs. 云:大数据厮杀的最终幸存者会是谁?— InfoQ专访阿里云智能通用计算平台负责人关涛
摘要: 本地大数据服务是否进入消失倒计时?云平台大数据服务最终到底会趋向多云.混合云还是单一公有云?集群规模增大,上云成本将难以承受是误区还是事实?InfoQ 将就上述问题对阿里云智能通用计算平台负责 ...
- Function相关的小知识
重载 相同函数名,不同参数列表的多个函数,在调用时可自动根据传入参数的不同,选择对应的函数执行.为什么使用重载: 减轻API的名字,减轻调用者的负担.何时使用重 ...
- 剑指offer 1-6
1. 二维数组中的查找 在一个二维数组中,每一行都按照从左到右递增的顺序排序,每一列都按照从上到下递增的顺序排序.请完成一个函数,输入这样的一个二维数组和一个整数,判断数组中是否含有该整数. 分析 ...