参考学习博客:

# https://www.cnblogs.com/felixwang2/p/9190692.html

一、模型保存
 # https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size =
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, ])
y = tf.placeholder(tf.float32, [None, ]) # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([, ]))
b = tf.Variable(tf.zeros([]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, ), tf.argmax(prediction, )) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
for epoch in range():
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# 保存模型
saver.save(sess, 'net/my_net.ckpt')
输出结果:
Iter ,Testing Accuracy 0.8629
Iter ,Testing Accuracy 0.896
Iter ,Testing Accuracy 0.9028
Iter ,Testing Accuracy 0.9052
Iter ,Testing Accuracy 0.9085
Iter ,Testing Accuracy 0.9099
Iter ,Testing Accuracy 0.9122
Iter ,Testing Accuracy 0.9139
Iter ,Testing Accuracy 0.9148
Iter ,Testing Accuracy 0.9163
Iter ,Testing Accuracy 0.9165

二、模型载入
 # https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size = 100
# 计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔值列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
未载入识别率 0.098
载入后识别率 0.9178

程序输出如上结果。

TensorFlow 模型的保存与载入的更多相关文章

  1. TensorFlow——训练模型的保存和载入的方法介绍

    我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...

  2. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  3. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  4. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  5. Tensorflow模型变量保存

    Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...

  6. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  7. tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...

  8. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  9. TensorFlow(十三):模型的保存与载入

    一:保存 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...

随机推荐

  1. .NET知识梳理——3.面向对象

    1. 面向对象 1.1        封装.继承.多态理解 1.1.1  封装 封装就是将抽象得到的数据和行为(或功能)相结合,形成一个有机的整体,也就是将数据与操作数据的源代码进行有机的结合,形成“ ...

  2. Maven项目中配置文件导出问题

    1.将该设置写在pom.xml中 <build> <resources> <resource> <directory>src/main/resource ...

  3. 变色html css js

    <!DOCTYPE html><html> <head> <meta charset="utf-8" /> <title> ...

  4. Python入门2 —— 变量

    一:问号三连 1.什么是变量? 变 指的是事物的状态是可以发生变化的 量 指的是记录事物的状态 2.为什么要有变量? 为了让计算机像人一样去记录事物的状态 3.怎么用变量? 先定义 后引用 二:变量的 ...

  5. testng如何实现用例间依赖

    todo: 参考: https://www.cnblogs.com/znicy/p/6534893.html

  6. Oracle 11G在用EXP 导入、导出时,若有空表对导入导出中遇到的问题的解决

    11G中有个新特性,当表无数据时,不分配segment,以节省空间 解决方法: 1.insert一行,再rollback就产生segment了. 该方法是在在空表中插入数据,再删除,则产生segmen ...

  7. Jarvis OJ - DD-Hello -Writeup

    Jarvis OJ - DD-Hello -Writeup 转载请注明出处http://www.cnblogs.com/WangAoBo/p/7239216.html 题目: 分析: 第一次做这道题时 ...

  8. 线索二叉树的详细实现(C++)

    线索二叉树概述 二叉树虽然是非线性结构,但二叉树的遍历却为二又树的结点集导出了一个线性序列.希望很快找到某一结点的前驱或后继,但不希望每次都要对二叉树遍历一遍,这就需要把每个结点的前驱和后继信息记录下 ...

  9. ubuntu 下的ftp详细配置

    FTP(文件传输协议)是一个较老且最常用的标准网络协议,用于在两台计算机之间通过网络上传/下载文件.然而, FTP 最初的时候并不安全,因为它仅通过用户凭证(用户名和密码)传输数据,没有进行加密. 警 ...

  10. Java进阶学习(6)之抽象与接口

    抽象与接口 抽象 抽象函数 表达概念而无法实现具体代码的函数 抽象类 表达概念而无法构造出实体的类 有抽象函数的类也可以有非抽象函数 实现抽象函数 继承自抽象类的子类必须覆盖父类中的抽象函数 抽象 与 ...