estimator同keras是tensorflow的高级API。在tensorflow1.13以上,estimator已经作为一个单独的package从tensorflow分离出来了。
estimator抽象了tensorflow底层的api, 同keras一样,他分离了model和data, 不同于keras这个不得不认养的儿子,estimator作为tensorflow的亲儿子,天生具有分布式的基因,更容易在生产环境里面使用

tensorflow官方文档提供了比较详细的estimator程序的构建过程:
https://www.tensorflow.org/guide#estimators

tensorflow model提供了estimator构建的mnist程序:
https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

estimator模型由model_fn决定:
官方文档:

其中features, labels是必需的。model, params, config参数是可选的
如下是estiamtor定义的一个模型:

  1. 1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
  1. def (features, labels, mode, params):
    """DNN with three hidden layers and learning_rate=0.1."""
  2.  
  3. net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params['hidden_units']:
    net = t 大专栏  tf.estimatorf.layers.dense(net, units=units, activation=tf.nn.relu)
  4.  
  5. # Compute logits (1 per class).
    logits = tf.layers.dense(net, params['n_classes'], activation=None)
  6.  
  7. # Compute predictions.
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
    'class_ids': predicted_classes[:, tf.newaxis],
    'probabilities': tf.nn.softmax(logits),
    'logits': logits,
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
  8.  
  9. # Compute loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  10.  
  11. # Compute evaluation metrics.
    accuracy = tf.metrics.accuracy(labels=labels,
    predictions=predicted_classes,
    name='acc_op')
    metrics = {'accuracy': accuracy}
    tf.summary.scalar('accuracy', accuracy[1])
  12.  
  13. if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=metrics)
  14.  
  15. # Create training op.
    assert mode == tf.estimator.ModeKeys.TRAIN
  16.  
  17. optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

可以看见features, labels分别是模型的input,output。所以是必须的。但是在有的模型里面。比如bert预训练的模型里面,我们不需要训练,所以只用到features。

tf.estimator的更多相关文章

  1. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  2. import tensorflow 报错: tf.estimator package not installed.

    import tensorflow 报错: tf.estimator package not installed. 解决方案1: 安装 pip install tensorflow-estimator ...

  3. tf.estimator.Estimator类的用法

    官网链接:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator Estimator - 一种可极大地简化机器学习编程的高阶 ...

  4. tf.estimator.Estimator

    1.定义 tf.estimator.Estimator(model_fn=model_fn) #model_fn是一个方法 2.定义model_fn: def model_fn_builder(sel ...

  5. 启动Tensorboard时发生错误:class BeholderHook(tf.estimator.SessionRunHook): AttributeError: module 'tensorflow.python.estimator.estimator_lib' has no attribute 'SessionRunHook'

    报错:class BeholderHook(tf.estimator.SessionRunHook):AttributeError: module 'tensorflow.python.estimat ...

  6. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

  7. tensorflow创建自定义 Estimator

    https://www.tensorflow.org/guide/custom_estimators?hl=zh-cn 创建自定义 Estimator 本文档介绍了自定义 Estimator.具体而言 ...

  8. Tensorflow1.4 高级接口使用(estimator, data, keras, layers)

    TensorFlow 高级接口使用简介(estimator, keras, data, experiment) TensorFlow 1.4正式添加了keras和data作为其核心代码(从contri ...

  9. TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...

随机推荐

  1. 17.3.12--uillib模块

    1---uillib是python标准库中最常用的一个python网络应用资源访问的模块,他可以让你像访问文本一样,读取网页的内容 它的作用是访问一些不需要验证的网络资源和cookie等 uillib ...

  2. js中call和apply的实现原理

    js中call和apply的实现原理            实现call的思路: /* 还有就是call方法是放在Function().prototype上的也就是构造函数才有的call方法 (我门可 ...

  3. java截取字符串并拼接

    一.substirng public static void main(String[] args) { String sendContent = "请查收:www.baidu.com&qu ...

  4. mui + H5 调取摄像头和相册 实现图片上传

    最近要用MUI做项目,在研究图片上传时 ,遇到了大坑 ,网上搜集各种资料,最终写了一个demo,直接看代码.参考(http://www.cnblogs.com/richerdyoung/p/66123 ...

  5. sqlserver修改某列为自增

    sqlserver如果建表的时候不设自增,之后是没法直接修改的,需要先删再重设: alter table 表名 drop column ID alter table 表名 add ID int ide ...

  6. python学习笔记(31)——日志格式

  7. c语言删除文件的指定行,更新文件

    有时候我们需要删除文件的某一行,来更新文件,在这我个人扩展了一个函数,以删除指定条件的行. static void UpdateHistoryFile(void) { FILE *fin,*fout; ...

  8. Python基础学习五

    Python基础学习五 迭代 for x in 变量: 其中变量可以是字符串.列表.字典.集合. 当迭代字典时,通过字典的内置函数value()可以迭代出值:通过字典的内置函数items()可以迭代出 ...

  9. Linux 进程信号量

    #include<stdlib.h> #include<stdio.h> #include<sys/types.h> #include<sys/ipc.h&g ...

  10. Linux 使用rpm方式安装最新mysql(5.7)步骤以及常见问题解决

    第一步:下载rpm包 mysql官网下载:http://dev.mysql.com/downloads/mysql/ 但如果你的下载网速不好的话也可以点下面的链接下载自己想要的版本 http://mi ...