【转载】 Tensorflow学习笔记-模型保存与加载
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/lovelyaiq/article/details/78646401
————————————————
保存模型时,文件格式有两种,ckpt和pb格式,这两种格式的模型区别是什么呢?首先看一下英文的解释。并且我们的学习中也要养成看英文文档的习惯,其一:老外写的东西通俗易懂,其二,在翻译时,每个人的英文理解不同,原汁原味的道理就没有了。
The .ckpt is the model given by tensorflow which includes all the
weights/parameters in the model. The .pb file stores the computational
graph. To make tensorflow work we need both the graph and the
parameters. There are two ways to get the graph:
(1) use the python program that builds it in the first place (tensorflowNetworkFunctions.py).
(2) Use a .pb file (which would have to be generated by tensorflowNetworkFunctions.py).
.ckpt file is were all the intelligence is.
使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。
保存
Tensorflow的模型保存时有几点需要注意:
1、利用tf.train.write_graph() 默认情况下只导出了网络的定义(没有权重weight)。
2、利用tf.train.Saver().save() 导出的文件graph_def与权重是分离的,就像上述英文的描述。
我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标
import tensorflow as tf
v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(v1))
print(sess.run(v2))
print(sess.run(result))
saver.save(sess,'model/model.ckpt')
模型保存后,在model目录将会有三个文件。在Tensorflow版本0.11之前,这三个文件为:meta、ckpt、checkpopint,它们保存的内容如下:
model.ckpt.meta保存计算图的结构,即神经网络的结构
checkpoint保存一个目录下所有的模型文件列表。
ckpt 保存程序中每一个变量的取值。
在Tensorflow版本0.11之后,有四个文件分别为:meta、.data、.index、checkpoint。其中.data文件为模型中的训练变量。
模型加载
模型加载包含两种方式,它们的区分以是否含有计算图上的所有运算。
包含所有运算
import tensorflow as tf v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
print(sess.run(v1+v2))
这种方法加载模型时和保存模型时的代码基本上是一致的,唯一不同的就是没有变量的初始化过程。
模型加载的时候,如果某个变量没有被加载,则系统将会报错。我们可否使用已经定义好的其它变量来加载呢?当然是可以了,因为Tensorflow是支持的,这需要通过字典的形式来完成,将模型中的变量名重名为我们已经定好的其它变量名。
import tensorflow as tf x = tf.Variable(tf.constant(1,shape = [1]),name='x')
y = tf.Variable(tf.constant(2,shape = [1]),name='y')
result = x + y # 通过字典将变量重命名
saver = tf.train.Saver(
{'v1':x,'v2':y}) with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
out = tf.get_default_graph().get_tensor_by_name('add:0')
print(sess.run(out))
使用变量的滑动平均值的模型保存与加载详见:http://blog.csdn.net/lovelyaiq/article/details/78647850
不包含所有运算
import tensorflow as tf saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt') #获取节点名称
result = tf.get_default_graph().get_tensor_by_name("add:0")
print(sess.run(result))
Saver类
模型的加载与保存都使用到Saver类,该类的初始化参数为:
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
这里面主要用到的参数:
max_to_keep:保存checkpoint文件的最大数量,默认值为5.
keep_checkpoint_every_n_hours:经过多长时间后,只保留一个checkpoint文件,这是方便验证模型训练多长时间后的性能。默认值为10000.0。
而tf.train.save的参数为:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True):
使用global_step和write_meta_graph两个参数可以很好的保存模型。
saver.save(sess, 'my_test_model',global_step=1000)
#保存的文件为:
#my_test_model-1000.index
#my_test_model-1000.meta
#my_test_model-1000.data-00000-of-00001
#checkpoint
模型在保存的时候,计算图在第一次已经保存过,并且随着训练的进行,计算图是不会改变的,因此以后的保存,就可以使用write_meta_graph=True不保存计算图。
saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)
tf.train.Saver()默认保存与加载计算图上所有信息。但有时我们只需要保存或加载部分信息。比如在测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播到输出层即可,而不需要类似于变量的初始化、模型保存等辅助节点的信息。而且有时将变量的取值与计算图分开保存是不方便的,因此就需要借助 convert_variables_to_constants 将计算图上所有的变量及其取值通过常量保存,这样整个计算图将会保存到一个文件中。
关于 convert_variables_to_constants 的源码定义如下:从解释中看出,当把网络完全转换为single GraphDef file,它可以删除与加载和保存变量相关的很多操作。
def convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None,variable_names_blacklist=None):
"""Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op) # 导出计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_def = tf.get_default_graph().as_graph_def() # print(graph_def) # 在这里我们只关心"add"节点,因此其它的节点就没有必要导出。
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add']) # 将导出的模型保存到本地
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString())
导出模型的恢复:
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def保存的图加入到当前默认的图
result = tf.import_graph_def(graph_def,return_elements=['add:0'])
print(sess.run(result))
上述方法有一个缺点,那就是我们不能自己定义一个网络输入的placeholder接口,这是不是很蛋筒,不要着急,Tensorflow是可以满足我们的需求。
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 with tf.variable_scope('foo'):
x = tf.get_variable('x',shape=[1],initializer=tf.constant_initializer(1.0))
y = tf.get_variable('y', shape=[1], initializer=tf.constant_initializer(2.0))
# v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
# v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
input_tensor = tf.placeholder(tf.float32,shape=[1],name='input-x')
new_tensor = tf.placeholder(tf.float32, shape=[1], name='input-y') result = tf.add((x+y),input_tensor,name='sum') data = np.array([15], dtype=np.float32) init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# print(sess.run(result,feed_dict={input_tensor:data}))
# print(sess.run(result))
graph_def = tf.get_default_graph().as_graph_def()
# print(graph_def)
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['foo/sum'])
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString()) # 模型恢复
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) # 使用input_map将模型中的placeholder通信映射到重新定义的placeholder。
result1 = tf.import_graph_def(graph_def ,input_map={'foo/input-x:0':new_tensor},return_elements=['foo/sum:0'],name='') # [array([ 18.], dtype=float32)]
print(sess.run(result1,feed_dict={new_tensor:data}))
这种模型恢复的方法在迁移学习中是常用的方法,至于什么是迁移学习,请参考博客:
【转载】 Tensorflow学习笔记-模型保存与加载的更多相关文章
- 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)
文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
- Flutter学习笔记(19)--加载本地图片
如需转载,请注明出处:Flutter学习笔记(19)--加载本地图片 上一篇博客正好用到了本地的图片,记录一下用法: 首先新建一个文件夹,这个文件夹要跟目录下 然后在pubspec.yaml里面声明出 ...
- [置顶] iOS学习笔记47——图片异步加载之EGOImageLoading
上次在<iOS学习笔记46——图片异步加载之SDWebImage>中介绍过一个开源的图片异步加载库,今天来介绍另外一个功能类似的EGOImageLoading,看名字知道,之前的一篇学习笔 ...
- sklearn模型保存与加载
sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...
- Tensorflow学习笔记----模型的保存和读取(4)
一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- tensorflow学习笔记1:导出和加载模型
用一个非常简单的例子学习导出和加载模型: 导出 写一个y=a*x+b的运算,然后保存graph: import tensorflow as tf from tensorflow.python.fram ...
随机推荐
- ubuntu server 网速测试
ubuntu server 网速测试 speedtest-cli是一个用于测试网络带宽的命令行工具,可以快速测量下载和上传速度.你可以按照以下步骤安装和使用它: 打开终端. 安装speedtest-c ...
- session 和 cookie 有什么区别?
a.存储位置不同:session 存储在服务器端:cookie 存储在浏览器端. b.安全性不同:cookie 安全性一般,在浏览器存储,可以被伪造和修改. c.容量和个数限制:cookie 有容量限 ...
- OAuth + Security - 2 - 资源服务器配置
PS:此文章为系列文章,建议从第一篇开始阅读. 资源服务器配置 @EnableResourceServer 注解到一个@Configuration配置类上,并且必须使用ResourceServerCo ...
- mysql连接超时的属性设置
mysql连接超时的属性设置 2022-10-26 11:09:54.128 [http-nio-6788-exec-5] ERROR o.s.t.i.TransactionAspectSupport ...
- RabbitMQ 3.7.9版本中,Create Channel超时的常见原因及排查方法
在RabbitMQ 3.7.9版本中,Create Channel超时的常见原因及排查方法如下: 常见原因 网络问题: 网络延迟或不稳定可能导致通信超时. 网络分区(network partition ...
- Vue2 整理(一):基础篇
前言 首先说明:要直接上手简单得很,看官网熟悉大概有哪些东西.怎么用的,然后简单练一下就可以做出程序来了,最多两天,无论Vue2还是Vue3,就都完全可以了,Vue3就是比Vue2多了一些东西而已,所 ...
- Maven的依赖详解和打包方式
设置maven maven下载与安装教程: https://blog.csdn.net/YOL888666/article/details/122008374 1. 在File->setting ...
- yb课堂实战之轮播图接口引入本地缓存 《二十一》
轮播图接口引入缓存 CacheKeyManager.java package net.ybclass.online_ybclass.config; /** * 缓存key管理类 */ public c ...
- Oracle 日期减年数、两日期相减
-- 日期减年数 SELECT add_months(DEF_DATE,12*USEFUL_LIFE) FROM S_USER --两日期相减 SELECT round(sysdate-PEI.STA ...
- mybatis 逆行工程 附源码
导读 逆向工程说白了,就可以简化开发工作量,自动生成一些死板的东西,比如POJO.映射文件等等,然后在将代码拷贝至实际工程,直接拿来用! 项目结构 GeneratorSqlMap.java impor ...