tensorflow保存读取-【老鱼学tensorflow】
当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。
保存模型的权重和偏置值
假设我们已经训练好了模型,其中有关于weights和biases的值,例如:
import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
然后我们初始化这些变量的值,假装是训练后被设置上的值:
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
最后进行保存:
# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)
这样在打印出:
保存的路径为: D:/todel/python/saver/save_net.ckpt
在那个目录下,我们看到:
这样,这些训练后的参数就被保存起来了。
完整的保存参数的代码为:
import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)
恢复模型的权重和偏置值
在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。
首先定义要恢复的权重和偏置值的结构:
import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
注意:其中的name要跟之前保存时一致。
然后进行加载:
saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
这样输出为:
weights: [[ 1. 2. 3.]
[ 3. 4. 5.]]
biases: [[ 1. 2. 3.]]
就是前面我们保存的内容被恢复出来了。
完整的恢复代码为:
import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
tensorflow保存读取-【老鱼学tensorflow】的更多相关文章
- tensorflow分类-【老鱼学tensorflow】
前面我们学习过回归问题,比如对于房价的预测,因为其预测值是个连续的值,因此属于回归问题. 但还有一类问题属于分类的问题,比如我们根据一张图片来辨别它是一只猫还是一只狗.某篇文章的内容是属于体育新闻还是 ...
- tensorflow安装-【老鱼学tensorflow】
TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,Tensor ...
- tensorflow例子-【老鱼学tensorflow】
本节主要用一个例子来讲述一下基本的tensorflow用法. 在这个例子中,我们首先伪造一些线性数据点,其实这些数据中本身就隐藏了一些规律,但我们假装不知道是什么规律,然后想通过神经网络来揭示这个规律 ...
- tensorflow变量-【老鱼学tensorflow】
在程序中定义变量很简单,只要定义一个变量名就可以,但是tensorflow有点类似在另外一个世界,因此需要通过当前的世界中跟tensorlfow的世界中进行通讯,来告诉tensorflow的世界中定义 ...
- tensorflow激励函数-【老鱼学tensorflow】
当我们回到家,如果家里有异样,我们能够很快就会发现家中的异样,那是因为这些异常的摆设在我们的大脑中会产生较强的脑电波. 当我们听到某个单词,我们大脑中跟这个单词相关的神经元会异常兴奋,而同这个单词无关 ...
- tensorflow卷积神经网络-【老鱼学tensorflow】
前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...
- tensorflow Tensorboard可视化-【老鱼学tensorflow】
tensorflow自带了可视化的工具:Tensorboard.有了这个可视化工具,可以让我们在调整各项参数时有了可视化的依据. 本次我们先用Tensorboard来可视化Tensorflow的结构. ...
- tensorflow 传入值-【老鱼学tensorflow】
上个文章中讲述了tensorflow中如何定义变量以及如何读取变量的方式,本节主要讲述关于传入值. 变量主要用于在tensorflow系统中经常会被改变的值,而对于传入值,它只是当tensorflow ...
- tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...
随机推荐
- Scrum【转】
转载自:https://www.cnblogs.com/l2rf/p/5783726.html 灵感来自于一段冷笑话: 一天,一头猪和一只鸡在路上散步,鸡看了一下猪说,“嗨,我们合伙开一家餐馆怎么样? ...
- mysql的服务器构成
什么是实例 这里的实例不是类产生的实例对象,而是Linux系统下的一种机制 1.MySQL的后台进程+线程+预分配的内存结构. 2.MySQL在启动的过程中会启动后台守护进程,并生成工作线程,预分配内 ...
- 《Linux下cp XXX1 XXX2的功能》的实现
<Linux下cp XXX1 XXX2的功能>的实现 一.题目要求 编写MyCP.java 实现类似Linux下cp XXX1 XXX2的功能,要求MyCP支持两个参数: java MyC ...
- JMeter的介绍和简单使用
Apache官网(https://jmeter.apache.org/)对JMeter的解释: Apache JMeter™ Apache JMeter™应用程序是开源软件, 为负载功能和性能测试 ...
- R语言入门(2)-数据对象
数据对象 创建向量相关的方法 R语言的向量用法非常像python, 就比如这个seq(0,10,2), 从0到10, 步长为2, 涉及到的元素作为向量里的内容进行创建. 这里的用法非常像Matlab, ...
- localhost 将您重定向的次数过多
localhost 将您重定向的次数过多 问题描述:在项目中,出现 localhost 将您重定向的次数过多 ,有可能是因为设置重定向的时候,自己重定向到自己,或者重定向成环,导致无限的重定向.检查重 ...
- mysql登录报错“Access denied for user 'root'@'localhost' (using password: YES”)的处理方法
环境 CentosOS 6.5 ,已安装mysql 情景 root密码忘记,使用普通用户无法登录 解决 问题一 无法使用mysql命令 参考文章:https://www.cnblogs.com/com ...
- 使用docker中mysql镜像
1.拉取mysql镜像 docker pull mysql:5.6 2.运行mysql的镜像生成一个正在运行的容器,可以通过docker contain ls得到容器的id信息 docker run ...
- 简单易懂的解释c#的abstract和virtual的用法和区别
先来看abstract方法,顾名思义,abstract方法就是抽象方法. 1.抽象方法就是没有实现的,必须是形如: public abstract void Init(); 2.拥有抽象方法的类必须修 ...
- ZOJ1008
题目: ZOJ 1008 分析: 重排矩阵, 虽然题目给的时间很多, 但是要注意剪枝, 把相同的矩阵标记, 在搜索时可以起到剪枝效果. Code: #include <bits/stdc++.h ...