# 1. 自定义模型并训练。
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data tf.logging.set_verbosity(tf.logging.INFO) def lenet(x, is_training):
x = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
conv1 = tf.layers.max_pooling2d(conv1, 2, 2) conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
conv2 = tf.layers.max_pooling2d(conv2, 2, 2) fc1 = tf.contrib.layers.flatten(conv2)
fc1 = tf.layers.dense(fc1, 1024)
fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
return tf.layers.dense(fc1, 10) def model_fn(features, labels, mode, params):
predict = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN) if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,predictions={"result": tf.argmax(predict, 1)}) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"]) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) eval_metric_ops = {"accuracy": tf.metrics.accuracy(tf.argmax(predict, 1), labels)} return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops) mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data", one_hot=False) model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True) estimator.train(input_fn=train_input_fn, steps=30000)

# 2. 在测试数据上测试模型。
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False) test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score*100))
# 3. 预测过程。
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False) predictions = estimator.predict(input_fn=predict_input_fn)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["result"]))

吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型的更多相关文章

  1. 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier

    # 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  2. 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API

    # 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...

  3. 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  4. 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  5. 吴裕雄--天生自然TensorFlow高层封装:Keras-CNN

    # 1. 数据预处理 import keras from keras import backend as K from keras.datasets import mnist from keras.m ...

  6. 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型

    # 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...

  7. 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型

    # 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...

  8. 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN

    # 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...

  9. 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'

    将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.

随机推荐

  1. InvalidOperationException: Cannot create a DbSet for 'IdentityUserClaim<string>' because this type is not included in the model for the context.

    An unhandled exception occurred while processing the request. InvalidOperationException: Cannot crea ...

  2. Django——整体结构/MVT设计模式

    MVT设计模式 Models      封装数据库,对数据进行增删改查; Views        进行业务逻辑的处理 Templates  进行前端展示 前端展示看到的是模板  -->  发起 ...

  3. 九十七、SAP中ALV事件之十,通过REUSE_ALV_COMMENTARY_WRITE函数来显示ALV的标题

    一.SE37查看REUSE_ALV_COMMENTARY_WRITE函数 二.查看一下导入 三.我们点击SLIS_T_LISTHEADER,来看一下类型 四.我们再看一下,这个info是60长度的字符 ...

  4. 第十七篇 ORM跨表查询和分组查询---二次剖析

    ORM跨表查询和分组查询---二次剖析 阅读目录(Content) 创建表(建立模型) 基于对象的跨表查询 一对多查询(Publish与Book) 多对多查询 (Author 与 Book) 一对一查 ...

  5. jar类库加载顺序

    当我们启动一个tomcat的服务的时候,jar包和claess文件加载顺序: 1. $java_home/lib 目录下的java核心api 2. $java_home/lib/ext 目录下的jav ...

  6. 项目版本回退后出现java compiler level does not match the version of the installed java project facet错误的解决

    今天项目出问题了,采取了项目版本回退的方法解决了代码不能够下拉和上送的问题以后,出现如下错误,项目是微服务的,更新相关的依赖项目,仍得不到解决,检查mapper.xml文件亦没问题.然后在控制台那块发 ...

  7. JS ~ Promise.reject()

    概述: Promise.reject(reason)方法返回一个带有拒绝原因reason参数的Promise对象. 语法 Promise.reject(reason); reason :  表示Pro ...

  8. 【LeetCode】跳跃游戏II

    [问题]给定一个非负整数数组,你最初位于数组的第一个位置.数组中的每个元素代表你在该位置可以跳跃的最大长度.你的目标是使用最少的跳跃次数到达数组的最后一个位置. 示例: 输入: [,,,,] 输出: ...

  9. maven项目中WEB-INF的父目录必须叫webapp吗?

    这个并不是必须的,可以在pom配置文件中修改,如下所示: <webappDirectory>src/main/WebContent</webappDirectory>      ...

  10. gogs 小团队使用 2

    gogs 团队使用第二种方法如下, 前面办法参考前面的方法: 由 root 用户新建 organization, 比如说建立 hardware,然后把团队的 技术负责人拉到 owners 这个 tea ...