tf.estimator.Estimator类的用法
官网链接:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator
Estimator - 一种可极大地简化机器学习编程的高阶 TensorFlow API。Estimator 会封装下列操作:
- 训练
- 评估
- 预测
- 导出以供使用
您可以使用官方提供的预创建的 Estimator,也可以编写自定义 Estimator。所有 Estimator(无论是预创建的还是自定义)都是基于 tf.estimator.Estimator
类的类。
Estimator 的优势
Estimator 具有下列优势:
- 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。
- Estimator 简化了在模型开发者之间共享实现的过程。
- 您可以使用高级直观代码开发先进的模型。简言之,采用 Estimator 创建模型通常比采用低阶 TensorFlow API 更简单。
- Estimator 本身在
tf.layers
之上构建而成,可以简化自定义过程。 - Estimator 会为您构建图。
- Estimator 提供安全的分布式训练循环,可以控制如何以及何时:
- 构建图
- 初始化变量
- 开始排队
- 处理异常
- 创建检查点文件并从故障中恢复
- 保存 TensorBoard 的摘要
使用 Estimator 编写应用时,您必须将数据输入管道从模型中分离出来。这种分离简化了不同数据集的实验流程。
预创建的 Estimator
借助预创建的 Estimator,您能够在比基本 TensorFlow API 高级很多的概念层面上进行操作。由于 Estimator 会为您处理所有“管道工作”,因此您不必再为创建计算图或会话而操心。也就是说,预创建的 Estimator 会为您创建和管理 Graph
和 Session
对象。此外,借助预创建的 Estimator,您只需稍微更改下代码,就可以尝试不同的模型架构。例如,DNNClassifier
是一个预创建的 Estimator 类,它根据密集的前馈神经网络训练分类模型。
预创建的 Estimator 程序的结构
依赖预创建的 Estimator 的 TensorFlow 程序通常包含下列四个步骤:
编写一个或多个数据集导入函数。 例如,您可以创建一个函数来导入训练集,并创建另一个函数来导入测试集。每个数据集导入函数都必须返回两个对象:
- 一个字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)
- 一个包含一个或多个标签的张量
例如,以下代码展示了输入函数的基本框架:
def input_fn(dataset):
... # manipulate dataset, extracting the feature dict and the label
return feature_dict, label
- 定义特征列。 每个
tf.feature_column
都标识了特征名称、特征类型和任何输入预处理操作。例如,以下代码段创建了三个存储整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个 lambda,该程序将调用此 lambda 来调节原始数据:
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
normalizer_fn=lambda x: x - global_education_mean)
- 实例化相关的预创建的 Estimator。 例如,下面是对名为
LinearClassifier
的预创建 Estimator 进行实例化的示例代码:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
feature_columns=[population, crime_rate, median_education],
)
- 调用训练、评估或推理方法。例如,所有 Estimator 都提供训练模型的
train
方法。
# my_training_set is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)
从 Keras 模型创建 Estimator
您可以将现有的 Keras 模型转换为 Estimator。这样做之后,Keras 模型就可以利用 Estimator 的优势,例如分布式训练。调用 tf.keras.estimator.model_to_estimator
,如下例所示:
# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3) # Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"input_1": train_data},
y=train_labels,
num_epochs=1,
shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)
class Estimator(builtins.object)
一 介绍
Estimator 类,用来训练和验证 TensorFlow 模型。
Estimator 对象包含了一个模型 model_fn,这个模型给定输入和参数,会返回训练、验证或者预测等所需要的操作节点。
所有的输出(检查点、事件文件等)会写入到 model_dir,或者其子文件夹中。如果 model_dir 为空,则默认为临时目录。
config 参数为 tf.estimator.RunConfig 对象,包含了执行环境的信息。如果没有传递 config,则它会被 Estimator 实例化,使用的是默认配置。
params 包含了超参数。Estimator 只传递超参数,不会检查超参数,因此 params 的结构完全取决于开发者。
Estimator 的所有方法都不能被子类覆盖(它的构造方法强制决定的)。子类应该使用 model_fn 来配置母类,或者增添方法来实现特殊的功能。
Estimator 不支持 Eager Execution(eager execution能够使用Python 的debug工具、数据结构与控制流。并且无需使用placeholder、session,计算结果能够立即得出)。
二 类内方法
1、__init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None)
构造一个 Estimator 的实例.。
参数:
model_fn: 模型函数。函数的格式如下:
参数:
1、features: 这是 input_fn 返回的第一项(input_fn 是 train, evaluate 和 predict 的参数)。类型应该是单一的 Tensor 或者 dict。
2、labels: 这是 input_fn 返回的第二项。类型应该是单一的 Tensor 或者 dict。如果 mode 为 ModeKeys.PREDICT,则会默认为 labels=None。如果 model_fn 不接受 mode,model_fn 应该仍然可以处理 labels=None。
3、mode: 可选。指定是训练、验证还是测试。参见 ModeKeys。
4、params: 可选,超参数的 dict。 可以从超参数调整中配置 Estimators。
5、config: 可选,配置。如果没有传则为默认值。可以根据 num_ps_replicas 或 model_dir 等配置更新 model_fn。
返回:
EstimatorSpec
model_dir: 保存模型参数、图等的地址,也可以用来将路径中的检查点加载至 estimator 中来继续训练之前保存的模型。如果是 PathLike, 那么路径就固定为它了。如果是 None,那么 config 中的 model_dir 会被使用(如果设置了的话),如果两个都设置了,那么必须相同;如果两个都是 None,则会使用临时目录。
config: 配置类。
params: 超参数的dict,会被传递到 model_fn。keys 是参数的名称,values 是基本 python 类型。
warm_start_from: 可选,字符串,检查点的文件路径,用来指示从哪里开始热启动。或者是 tf.estimator.WarmStartSettings 类来全部配置热启动。如果是字符串路径,则所有的变量都是热启动,并且需要 Tensor 和词汇的名字都没有变。
异常:
RuntimeError: 开启了 eager execution
ValueError:model_fn 的参数与 params 不匹配
ValueError:这个函数被 Estimator 的子类所覆盖
2、train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)
根据所给数据 input_fn, 对模型进行训练。
参数:
input_fn:一个函数,提供由小 batches 组成的数据, 供训练使用。必须返回以下之一:
1、一个 'tf.data.Dataset'对象:Dataset的输出必须是一个元组 (features, labels),元组要求如下。
2、一个元组 (features, labels):features 是一个 Tensor 或者一个字典(特征名为 Tensor),labels 是一个 Tensor 或者一个字典(特征名为 Tensor)。features 和 labels 都被 model_fn 所使用,应该符合 model_fn 输入的要求。
hooks:SessionRunHook 子类实例的列表。用于在训练循环内部执行。
steps:模型训练的步数。如果是 None, 则一直训练,直到input_fn 抛出了超过界限的异常。steps 是递进式进行的。如果执行了两次训练(steps=10),则总共训练了 20 次。如果中途抛出了越界异常,则训练在 20 次之前就会停止。如果你不想递进式进行,请换为设置 max_steps。如果设置了 steps,则 max_steps 必须是 None。
max_steps:模型训练的最大步数。如果为 None,则一直训练,直到input_fn 抛出了超过界限的异常。如果设置了 max_steps, 则 steps 必须是 None。如果中途抛出了越界异常,则训练在 max_steps 次之前就会停止。执行两次 train(steps=100) 意味着 200 次训练;但是,执行两次 train(max_steps=100) 意味着第二次执行不会进行任何训练,因为第一次执行已经做完了所有的 100 次。
saving_listeners:CheckpointSaverListener 对象的列表。用于在保存检查点之前或之后立即执行的回调函数。
返回:
self:为了链接下去。
异常:
ValueError:steps 和 max_steps 都不是 None
ValueError:steps 或 max_steps <= 0
3、evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
根据所给数据 input_fn, 对模型进行验证。
对于每一步,执行 input_fn(返回数据的一个 batch)。
一直进行验证,直到:
steps 个 batches 进行完毕,或者
input_fn 抛出了越界异常(OutOfRangeError 或 StopIteration)
参数:
input_fn:一个函数,构造了验证所需的输入数据,必须返回以下之一:
1、一个 'tf.data.Dataset'对象:Dataset的输出必须是一个元组 (features, labels),元组要求如下。
2、一个元组 (features, labels):features 是一个 Tensor 或者一个字典(特征名为 Tensor),labels 是一个 Tensor 或者一个字典(特征名为 Tensor)。features 和 labels 都被 model_fn 所使用,应该符合 model_fn 输入的要求。
steps:模型验证的步数。如果是 None, 则一直验证,直到input_fn 抛出了超过界限的异常。
hooks:SessionRunHook 子类实例的列表。用于在验证内部执行。
checkpoint_path: 用于验证的检查点路径。如果是 None, 则使用 model_dir 中最新的检查点。
name:验证的名字。使用者可以针对不同的数据集运行多个验证操作,比如训练集 vs 测试集。不同验证的结果被保存在不同的文件夹中,且分别出现在 tensorboard 中。
返回:
返回一个字典,包括 model_fn 中指定的评价指标、global_step(包含验证进行的全局步数)
异常:
ValueError:如果 step 小于等于0
ValueError:如果 model_dir 指定的模型没有被训练,或者指定的 checkpoint_path 为空。
4、predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True)
对给出的特征进行预测
参数:
input_fn:一个函数,构造特征。预测一直进行下去,直到 input_fn 抛出了越界异常(OutOfRangeError 或 StopIteration)。函数必须返回以下之一:
1、一个 'tf.data.Dataset'对象:Dataset的输出和以下的限制相同。
2、features:一个 Tensor 或者一个字典(特征名为 Tensor)。features 被 model_fn 所使用,应该符合 model_fn 输入的要求。
3、一个元组,其中第一项为 features。
predict_keys:字符串列表,要预测的键值。当 EstimatorSpec.predictions 是一个 dict 时使用。如果使用了 predict_keys, 那么剩下的预测值会从字典中过滤掉。如果是 None,则返回全部。
hooks:SessionRunHook 子类实例的列表。用于在预测内部回调。
checkpoint_path: 用于预测的检查点路径。如果是 None, 则使用 model_dir 中最新的检查点。
yield_single_examples:If False, yield the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
返回:
predictions tensors 的值
异常:
ValueError:model_dir 中找不到训练好的模型。
ValueError:预测值的 batch 长度不同,且 yield_single_examples 为 True。
ValueError:predict_keys 和 predictions 之间有冲突。例如,predict_keys 不是 None,但是 EstimatorSpec.predictions 不是一个 dict。
tf.estimator.Estimator类的用法的更多相关文章
- tf.estimator.Estimator
1.定义 tf.estimator.Estimator(model_fn=model_fn) #model_fn是一个方法 2.定义model_fn: def model_fn_builder(sel ...
- C#中timer类的用法
C#中timer类的用法 关于C#中timer类 在C#里关于定时器类就有3个 1.定义在System.Windows.Forms里 2.定义在System.Threading.Timer类 ...
- C#正则表达式Regex类的用法
C#正则表达式Regex类的用法 更多2014/2/18 来源:C#学习浏览量:36891 学习标签: 正则表达式 Regex 本文导读:正则表达式的本质是使用一系列特殊字符模式,来表示某一类字符串, ...
- 标准C++中的string类的用法总结
标准C++中的string类的用法总结 相信使用过MFC编程的朋友对CString这个类的印象应该非常深刻吧?的确,MFC中的CString类使用起来真的非常的方便好用.但是如果离开了MFC框架,还有 ...
- android中Handle类的用法
android中Handle类的用法 当我们在处理下载或是其他需要长时间执行的任务时,如果直接把处理函数放Activity的OnCreate或是OnStart中,会导致执行过程中整个Activity无 ...
- Handle类的用法
android中Handle类的用法 当我们在处理下载或是其他需要长时间执行的任务时,如果直接把处理函数放Activity的OnCreate或是OnStart中,会导致执行过程中整个Activity无 ...
- android application类的用法
android application类的用法 Application是android系统Framework提供的一个组件,它是单例模式(singleton),即每个应用只有一个实例,用来存储系统的一 ...
- php class类的用法详细总结
以下是对php中class类的用法进行了详细的总结介绍,需要的朋友可以过来参考下 一:结构和调用(实例化): class className{} ,调用:$obj = new className(); ...
- day319 1、正则表达式的定义及使用 2、Date类的用法 3、Calendar类的用法
1.正则表达式的定义及使用2.Date类的用法3.Calendar类的用法 一.正则表达式 ###01正则表达式的概念和作用* A: 正则表达式的概念和作用* a: 正则表达式的概述* 正则表达式也是 ...
随机推荐
- centos7下载
http://archive.kernel.org/centos-vault/7.0.1406/isos/x86_64/
- VS2015离线安装NuGet Package
在一些情况下,VS2015直接安装NuGet Package的时候,速度异常缓慢: 所以还是考虑直接离线安装: Step1: 下载相应的Package https://www.nuget.org/ 然 ...
- 通过linux核映射驱动访问GPIO
1. HPS GPIO原理 1.功能方块图 linux内核是通过Linux内核memory-mapped device驱动访问GPIO控制器的寄存器而控制HPS端用户的LED和KEY的.memory- ...
- 内置函数_zip()
zip() zip()函数用来把多个可迭代对象中的元素压缩到一起,返回一个可迭代的zip对象,其中每个元素都是包含原来的多个可迭代对象对应位置上元素的元组,最终结果中包含的元素个数取决于所有参数序列或 ...
- 分支结构-Switch
/* switch(表达式或变量){ case value1:{ 语句体1; break; } case value2:{ 语句体2; break; } ... default:{ 语句体n+1; b ...
- 瞎搞poj1008
http://poj.org/problem?id=1008 题意: 两种历法: 1.Haab,一年365天,共19个月,前18月有20天(编号为0-19),最后一个月有5天(编号为0-4)pop(1 ...
- Tomcat 多项目部署方法整理
Tomcat 多项目部署方法整理 说明:tomcat-deploy-aaa和tomcat-deploy-bbb是两个不同的web项目,为了方便以下简称aaa和bbb,请先自行创建并跑通 导航: NO1 ...
- 用scp这个命令来通过ssh传输文件
小结: 1. upload files 到 ssh 服务器 localhost $ scp localfile root@172.20.34.**:~/remotepath 2. 从 ssh 服务器d ...
- Android精通:View与ViewGroup,LinearLayout线性布局,RelativeLayout相对布局,ListView列表组件
UI的描述 对于Android应用程序中,所有用户界面元素都是由View和ViewGroup对象构建的.View是绘制在屏幕上能与用户进行交互的一个对象.而对于ViewGroup来说,则是一个用于存放 ...
- PHP、JS、Python,数据库 获取今天是星期几了?[开发篇]
额,这个看起来是一个好简单的问题,但是真正到自己去一行行写的时候,又给忘了,妈蛋.有空就看看吧.今天是星期几?下面就来看看几种不同语言的实现吧! PHP语言 输出当前时间: echo date('Y- ...