TensorFlow(十三):模型的保存与载入
一:保存
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
#保存模型
saver.save(sess,'net/my_net.ckpt')
结果:
Iter 0,Testing Accuracy 0.8252
Iter 1,Testing Accuracy 0.8916
Iter 2,Testing Accuracy 0.9008
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9091
Iter 5,Testing Accuracy 0.9104
Iter 6,Testing Accuracy 0.911
Iter 7,Testing Accuracy 0.9127
Iter 8,Testing Accuracy 0.9145
Iter 9,Testing Accuracy 0.9166
Iter 10,Testing Accuracy 0.9177
二:载入
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
saver.restore(sess,'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
结果:
未载入识别率 0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
载入后识别率 0.9177
TensorFlow(十三):模型的保存与载入的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
- tensorflow:模型的保存和训练过程可视化
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结
TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...
随机推荐
- java.lang.AbstractMethodError: null
在使用springcloud的时候运行报这个错,原因是版本冲突导致的,在idea中创建springcloud项目的时候,这里默认是${spring-cloud.version},但是如果你使用的是高版 ...
- 在论坛中出现的比较难的sql问题:13(循环替换问题 过滤各种标点符号)
原文:在论坛中出现的比较难的sql问题:13(循环替换问题 过滤各种标点符号) 所以,觉得有必要记录下来,这样以后再次碰到这类问题,也能从中获取解答的思路. 去掉一个字段中的标点符号的SQL语句怎么写 ...
- 正确使用SQLCipher来加密Android数据库
Android本身自带有不加密的数据库SQLite,如果要保存密码之类的敏感数据在本地的话方法一是使用字段加密解密算法,方法二是整个数据库都加密掉.如果只是加密解密某个字段(如password)就推荐 ...
- NoSql 使用小结
NoSql 使用小结 足够的冗余 如果出现要拿某个 id 去查另外的 collection 的情况,说明应该往这个增加所要查询的字段 实在要做关联查询的话,是不是应该考虑关系型的数据库,关系和非关系混 ...
- kafka的安装及使用(单节点)
介绍了linux环境下,kafka 服务的安装与配置 安装 jdk 环境 下载 kafka 源码包放到服务器,解压 开启 zookeeper 开启 kafka server 创建主题 开启生产者 开启 ...
- MySQL库的相关操作
再熟悉一下Mysql库.表.记录的基本操作. 库 增 create database userinfo1 charset utf8; 查 show databases; show create dat ...
- USB驱动分析
INIT函数: 这是内核模块的初始化函数,其所作的工作只有注册定义好的USB驱动结构体. USB驱动结构体如下: Usb_driver中的probe函数是驱动和设备匹配成功后调用. Usb_drive ...
- 解决Apache启动错误:httpd: Could not reliably determine the server's fully qualified domain name, using 127.0.0.1 for ServerName
启动apache遇到提示: [root@bqh-119 conf]# ../bin/apachectl -thttpd: apr_sockaddr_info_get() failed for bqh- ...
- C#导出和导入Excel模板功能
引用 Aspose.Cells; 基于WinForm 导入 private void btn_excel_input_Click(object sender, EventArgs e) { try ...
- 解决mysql登录警告问题
一.前言 我们在登录mysql的时候经常会看到一句警告: Warning: Using a password on the command line interface can be insecure ...