吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型
# 1. 自定义模型并训练。
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data tf.logging.set_verbosity(tf.logging.INFO) def lenet(x, is_training):
x = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
conv1 = tf.layers.max_pooling2d(conv1, 2, 2) conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
conv2 = tf.layers.max_pooling2d(conv2, 2, 2) fc1 = tf.contrib.layers.flatten(conv2)
fc1 = tf.layers.dense(fc1, 1024)
fc1 = tf.layers.dropout(fc1, rate=0.4, training=is_training)
return tf.layers.dense(fc1, 10) def model_fn(features, labels, mode, params):
predict = lenet(features["image"], mode == tf.estimator.ModeKeys.TRAIN) if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,predictions={"result": tf.argmax(predict, 1)}) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"]) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) eval_metric_ops = {"accuracy": tf.metrics.accuracy(tf.argmax(predict, 1), labels)} return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops) mnist = input_data.read_data_sets("F:\\TensorFlowGoogle\\201806-github\\datasets\\MNIST_data", one_hot=False) model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params) train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True) estimator.train(input_fn=train_input_fn, steps=30000)

# 2. 在测试数据上测试模型。
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False) test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score*100))
# 3. 预测过程。
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False) predictions = estimator.predict(input_fn=predict_input_fn)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["result"]))
吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型的更多相关文章
- 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier
# 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API
# 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值
# 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-CNN
# 1. 数据预处理 import keras from keras import backend as K from keras.datasets import mnist from keras.m ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型
# 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...
- 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型
# 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...
- 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN
# 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...
- 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'
将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.
随机推荐
- Centos7安装rabbitMQ3.6.0
文章中的erlang和rabbitmq3.6.0 http://pan.baidu.com/s/1c2Nn64w Centos7 系统操作 cd /etc/yum.repos.d/ mv Cent ...
- Codeforces Round #602 (Div. 2, based on Technocup 2020 Elimination Round 3) E. Arson In Berland Forest
E. Arson In Berland Forest The Berland Forest can be represented as an infinite cell plane. Every ce ...
- Django static配置
STATIC_URL = '/static/' # HTML中使用的静态文件夹前缀 STATICFILES_DIRS = [ os.path.join(BASE_DIR, "static&q ...
- LeetCode刷题(持续更新ing……)
准备刷题了!已经预见未来的日子是苦并快乐的了!虽然 N 年前刷过题,但现在感觉数据结构与算法的基本功快忘光了
- mybaits入门学习
学习了简单的mybatis的配置 Bean层: 这个都会很简单 一个完整的Bean 需要getter和setter方法还需要一个空的构造方法和一个满的构造方法. Dao层: 创建一个接口就ok了 pa ...
- 详解BurpSuite软件 请求包 HTTP (9.23 第十天)
HTTP协议基础 HTTP:HyperText Transfer Protocol,超文本传输协议 1.协议特点: 简单快速,请求方式get post head等8中请求方式 无连接(一次请求就断开) ...
- 量化投资_Multicharts数组操作函数_append()追加函数(自定义)
1. Multicharts中关于数组的操作比较麻烦,而且当中所谓的动态数组的定义并不是像其他语言那种的概念.因此要对数组进行元素“”追加“”的话,需要重新更改数组的索引,然后再最后一个位置添加val ...
- 关于jquery js读取excel文件内容 xls xlsx格式 纯前端
附带参考:http://blog.csdn.net/gongzhongnian/article/details/76438555 更详细导入导出:https://www.jianshu.com/p/7 ...
- Android自定义View——自定义ViewPager
第一部分:自定义ViewGroup的使用,手势识别器和Scroller滑动 第二部分:处理滑动监听,处理滑动冲突,增加ViewPager的指示器 常见的滑动冲突:外部滑动方向和内部滑动方向不一 ...
- meta标签小结
1.手机页面所需: <meta name="viewport" content="width=device-width,initial-scale=1.0,mini ...