1.定义

tf.estimator.Estimator(model_fn=model_fn) #model_fn是一个方法

2.定义model_fn:

    def model_fn_builder(self, bert_config, num_labels, init_checkpoint):
"""
:param bert_config:
:param num_labels:
:param init_checkpoint:
:param learning_rate:
:param num_train_steps:
:param num_warmup_steps:
:return:
"""
def model_fn(features, labels, mode, params):
"""
       这4个参数必须这样定义,就算是不用某个参数,也要把它定义出来
:param features: 是estimator传过来的feature
:param labels: 数据标签
:param mode: tf.estimator.TRAIN/tf.estimator.EVAL/tf.estimator.PREDICTION
:param params:这个暂时没弄懂
:return:
"""
input_ids = features['input_ids']
input_mask = features['input_mask']
segment_ids = features['segment_ids']
probabilities = self.creat_model(bert_config, input_ids, input_mask, segment_ids, num_labels) # 这里是重点,这里要定义模型和要取模型的什么值 tvars = tf.trainable_variables()
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) # assignment_map是模型所有的变量字典,init_checkpoint为模型文件
tf.train.init_from_checkpoint(init_checkpoint, assignment_map) # 加载模型 output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities) # 应为上面已经从create_model中获取了我们要做什么op,获取什么值,prediction为op或值
return output_spec return model_fn
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
"""Compute the union of the current variables and checkpoint variables."""
assignment_map = {}
initialized_variable_names = {} name_to_variable = collections.OrderedDict()
for var in tvars:
name = var.name
m = re.match("^(.*):\\d+$", name)
if m is not None:
name = m.group(1)
name_to_variable[name] = var init_vars = tf.train.list_variables(init_checkpoint) assignment_map = collections.OrderedDict()
for x in init_vars:
(name, var) = (x[0], x[1])
if name not in name_to_variable:
continue
assignment_map[name] = name
initialized_variable_names[name] = 1
initialized_variable_names[name + ":0"] = 1 return (assignment_map, initialized_variable_names)
    def creat_model(self, bert_config, input_ids, input_mask, segment_ids, num_labels):
""" :param bert_config:
:param input_ids:
:param input_mask:
:param segment_ids:
:param num_labels:
:return:
"""
model = modeling.BertModel(
config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=False) output_layer = model.get_pooled_output() hidden_size = output_layer.shape[-1].value
    
    
    # 获得已经训练好的值  
output_weights = tf.get_variable(
"output_weights", [num_labels, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02)) output_bias = tf.get_variable(
"output_bias", [num_labels], initializer=tf.zeros_initializer()) logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
probabilities = tf.nn.softmax(logits, axis=-1) return probabilities

2.使用estimator.predict

def predict(self, text_a, text_b):
""" :param text_a:
:param text_b:
:return:
""" def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f input_ids, input_mask, segment_ids = self.convert_single_example(text_a, text_b) features = collections.OrderedDict()
features['input_ids'] = create_int_feature(input_ids)
features['input_mask'] = create_int_feature(input_mask)
features['segment_ids'] = create_int_feature(segment_ids) tf_example = tf.train.Example(features=tf.train.Features(feature=features)) # 将feature转换为example self.writer.write(tf_example.SerializeToString())# 序列化example,写入tfrecord文件 result = self.estimator.predict(input_fn=self.predict_input_fn)
    def file_based_input_fn_builder(self):
""" :param examples:
:return:
"""
name_to_features = {
"input_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
"input_mask": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
"segment_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
} def decode_record(_examples, _name_to_feature):
""" :param _examples:
:param _name_to_feature:
:return:
""" return tf.parse_single_example(_examples, _name_to_feature) def input_fn():
""" :param params:
:return:
"""
d = tf.data.TFRecordDataset(self.predict_file) # 读取TFRecord文件
d = d.apply(
tf.data.experimental.map_and_batch(
lambda record: decode_record(record, name_to_features), # 将序列化的feature映射到字典上
batch_size=1,
drop_remainder=False)) return d # 这里返回的值会进入到定义estimator时的model_fn中,model_fn中的feature是d.get_next()的结果 return input_fn

1

tf.estimator.Estimator的更多相关文章

  1. tf.estimator.Estimator类的用法

    官网链接:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator Estimator - 一种可极大地简化机器学习编程的高阶 ...

  2. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  3. tensorflow创建自定义 Estimator

    https://www.tensorflow.org/guide/custom_estimators?hl=zh-cn 创建自定义 Estimator 本文档介绍了自定义 Estimator.具体而言 ...

  4. Tensorflow1.4 高级接口使用(estimator, data, keras, layers)

    TensorFlow 高级接口使用简介(estimator, keras, data, experiment) TensorFlow 1.4正式添加了keras和data作为其核心代码(从contri ...

  5. TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...

  6. 4. Tensorflow的Estimator实践原理

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  7. 使用 Estimator 构建卷积神经网络

    来源于:https://tensorflow.google.cn/tutorials/estimators/cnn 强烈建议前往学习 tf.layers 模块提供一个可用于轻松构建神经网络的高级 AP ...

  8. 创建自定义 Estimator

    ref 本文档介绍了自定义 Estimator.具体而言,本文档介绍了如何创建自定义 Estimator 来模拟预创建的 Estimator DNNClassifier 在解决鸢尾花问题时的行为.要详 ...

  9. TensorFlow之estimator详解

    Estimator初识 框架结构 在介绍Estimator之前需要对它在TensorFlow这个大框架的定位有个大致的认识,如下图示: 可以看到Estimator是属于High level的API,而 ...

随机推荐

  1. DS8800后端的光纤通道交换式互连方式

    DS8800 使用SAS 硬盘.使用了FC 到SAS 转换,光纤通道交换技术被用于DS8800 后端. FC 技术是普遍用于在一个光纤通道仲裁环路(Fibre Channel Arbitrated L ...

  2. Qt_MainWindow简介

    QMainWindow 是Qt框架带来的一个预定义好的主窗口类.按照建立HelloWorld程序建立工程,直接运行,或有一个空窗口. main().cpp #include "mainwin ...

  3. python生成器实例

    生成器是一种特殊的迭代器,它有yield语句 #coding:utf-8def fibs(max): n,a,b = 0,0,1 while n < max: yield b a , b = b ...

  4. brctl命令

    有五台主机.其中一台主机装有linux ,安装了网桥模块,而且有四块物理网卡,分别连接同一网段的其他主机.我们希望其成为一个网桥,为其他四台主机(IP分别为192.168.1.2 ,192.168.1 ...

  5. .net core 部署 docker (CentOS7)

    最近研究 docker 在Linux 下部署 .net core 项目,在过程中踩了很多坑,网上的资料对我帮助确实大,但有些问题未指明出来. 特地整理一份在发布文档 本文使用的是 root 账号操作, ...

  6. 解决SHAREJPOINT 跨域问题

    目前仅支持IE7/8不支持IE11和谷歌 对于跨域情况,目前找到如果jquery是get获取方式,可以配置web.config相关属性,具体powershell命令如下: Add-PSSnapin M ...

  7. Android 四大组件之" ContentProvider "

    前言 ContentProvider作为Android的四大组件之一,是属于需要掌握的基础知识,可能在我们的应用中,对于Activity和Service这两个组件用的很常见,了解的也很多,但是对Con ...

  8. Linux 安装JavaEE环境之Tomcat安装笔记

    1.先用xftp将tomcat的压缩包上传到 /opt/ 2.在/usr/local/下使用命令mkdir tomcat 创建tomcat目录 将apache-tomcat-7.0.70.tar.gz ...

  9. 【算法python实现】 -- 最大子序和

    原题:https://leetcode-cn.com/problems/maximum-subarray/ 问题描述: 输入:[-2, 1, -3, 4, -1, 2, 1, -5, 4], 输出:6 ...

  10. 《Python黑帽子:黑客与渗透测试编程之道》 自动化攻击取证

    工具安装: 下载源码:https://code.google.com/archive/p/volatility/downloads 工具配置: 获取内存镜像:https://www.downloadc ...