TensorFlow(十三):模型的保存与载入
一:保存
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
#保存模型
saver.save(sess,'net/my_net.ckpt')
结果:
Iter 0,Testing Accuracy 0.8252
Iter 1,Testing Accuracy 0.8916
Iter 2,Testing Accuracy 0.9008
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9091
Iter 5,Testing Accuracy 0.9104
Iter 6,Testing Accuracy 0.911
Iter 7,Testing Accuracy 0.9127
Iter 8,Testing Accuracy 0.9145
Iter 9,Testing Accuracy 0.9166
Iter 10,Testing Accuracy 0.9177
二:载入
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次100张照片
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
saver.restore(sess,'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
结果:
未载入识别率 0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
载入后识别率 0.9177
TensorFlow(十三):模型的保存与载入的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
- tensorflow:模型的保存和训练过程可视化
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结
TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...
随机推荐
- 【flume】5.采集日志进入hbase
设置我们的flume配置信息 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor li ...
- 动态script标签同步加载 ps:无打包编译,静态实现静态资源入口动态配置,无编译打包静态资源添加版本号
/**功能:创建动态标签加载css ,js文件,重点是js文件,利用onloading加递归实现动态标签的同步加载用法:在html文件body底部script内部声明并调用下列函数,obj中写要加载的 ...
- k8s部署nacos
如果是在centos7上直接启动nacos 注意修改启动命令 sh startup.sh -m standalone 访问路径 http://********:8848/nacos/index.h ...
- Windows终端命令行工具Cmder
在IT这一行,大部分情况下都是推荐大家使用Linux或者类Unix操作系统去编程,Linux作为一代优秀的操作系统,已经人尽皆知,在IT行业已经成为核心.有条件的大佬都选择了使用mac编程,最优秀的莫 ...
- github上更新fork的别人的项目
直接上解决方案的步骤 (1)在自己fork后的项目的位置上,点击New pull request. (2)比较和原创版本(base)的变化 (3 ) compare across forks. 使得左 ...
- JavaScript作用域、作用域链 学习随笔
(本文是这些知识点的自我理解.写之余从头回顾,加深理解.取得更多收获之用.) 作用域(scope) 程序设计概念,通常来说,一段程序代码中所用到的名字(JS叫标识符(如变量名.函数名.属性名.参数.. ...
- dao 接口定义了一个方法,报错 The method xxx is undefined for the type xxx;
转自:https://blog.csdn.net/panshoujia/article/details/78203837 持久层(DAO层)下的一个接口 ,eclipse报了一个The method ...
- 善用#waring,#pragma mark 标记
在项目开发中,我们不可能对着需求一口气将代码都写好.开发过程中肯定遇到诸如需求变动,业务逻辑沟通,运行环境的切换等这些问题.当项目大的时候,如果木有形成统一的代码规范,在项目交接和开发人员沟通上将会带 ...
- linux IPC简单学习
Posix和system v区别 所谓的IPC(进程间通信)指的是消息队列,共享内存,信号量3种机制合并起来,当然,这是个狭义的概念,只包含这三种.IPC又可以分为system v进程间通信和posi ...
- unity 刚体
刚体属性(rigidbody)标明物体受物理影响,包括重力,阻力等等. mass为重量,当大质量物体被小重量物体碰撞时只会发生很小的影响.. Drag现行阻力决定组件在没有发生物理行为下停止移动的速度 ...