『TensorFlow』第七弹_保存&载入会话_霸王回马

一、TensorFlow常规模型加载方法

保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法

参数名称 功能说明 默认值
var_list Saver中存储变量集合 全局变量集合
reshape 加载时是否恢复变量形状 True
sharded 是否将变量轮循放在所有设备上 True
max_to_keep 保留最近检查点个数 5
restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)

.meta文件保存了当前图结构

.data文件保存了当前参数名和值

.index文件保存了辅助索引信息

.data文件可以查询到参数名和参数值,使用下面的命令可以查询保存在文件中的全部变量{名:值}对,

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)

saver = tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。

'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg' with tf.Graph().as_default() as g: x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
y = Net.inference_1(x, N_CLASS=5, train=False) with tf.Session() as sess:
# 程序前面得有 Variable 供 save or restore 才不报错
# 否则会提示没有可保存的变量
saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state('./model/')
img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0)) if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess,'./model/model.ckpt-0')
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
res = sess.run(y, feed_dict={x: img})
print(global_step,sess.run(tf.argmax(res,1)))

2.加载图结构和参数

'''
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg' ckpt = tf.train.get_checkpoint_state('./model/') # 通过检查点文件锁定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 载入图结构,保存在.meta文件中 with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path) # 载入参数,参数保存在两个文件中,不过restore会自己寻找 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs})) '''
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
print(img)
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
feed_dict={'Placeholder:0':img}))

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

3.简化版本

# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path) # 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')
saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
# 新建GraphDef文件,用于临时载入模型中的图
graph_def = tf.GraphDef()
# GraphDef加载模型中的图
graph_def.ParseFromString(f.read())
# 在空白图中加载GraphDef中的图
tf.import_graph_def(graph_def,name='')
# 在图中获取张量需要使用graph.get_tensor_by_name加张量名
# 这里的张量可以直接用于session的run方法求值了
# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

『TensorFlow』迁移学习_他山之石,可以攻玉

『cs231n』通过代码理解风格迁移

上面两篇都使用了二进制加载模型的方式

三、二进制模型制作

这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。

tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,我们可以将tf的图和参数文件整合进一个后缀为pb的二进制文件中,由于整合过程回将变量转化为常量,所以我们在日后读取模型文件时不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称。

将图变量转换为常量的API:tf.graph_util.convert_variables_to_constants

转换后的graph_def对象转换为二进制数据(graph_def.SerializeToString())后,写入pb即可。

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2 saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, './tmodel/test_model.ckpt')
gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), ['add'])
with tf.gfile.GFile('./tmodel/model.pb', 'wb') as f:
f.write(gd.SerializeToString())

我们可以直接查看gd:

node {
  name: "v1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 1.0
      }
    }
  }
}
……
node {
  name: "add"
  op: "Add"
  input: "v1/read"
  input: "v2/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}

四、从图上读取张量

上面的代码实际上已经包含了本小节的内容,但是由于从图上读取特定的张量是如此的重要,所以我仍然单独的补充上这部分的内容。

无论如何,想要获取特定的张量我们必须要有张量的名称图的句柄,比如 'import/pool_3/_reshape:0' 这种,有了张量名和图,索引就很简单了。

从二进制模型加载张量

第二小节的代码很好的展示了这种情况

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # 瓶颈层输出张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' # 输入层张量名称
MODEL_DIR = './inception_dec_2015' # 模型存放文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb' # 模型名 # 加载模型
# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f: # 阅读器上下文
with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f: # 阅读器上下文
graph_def = tf.GraphDef() # 生成图
graph_def.ParseFromString(f.read()) # 图加载模型
# 加载图上节点张量(按照句柄理解)
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def( # 从图上读取张量,同时导入默认图
graph_def,
return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

从当前图中获取对应张量

这个就是很普通的情况,从我们当前操作的图中获取某个张量,用于feed啦或者用于输出等操作,API也很简单,用法如下:

g.get_tensor_by_name('import/pool_3/_reshape:0')

g表示当前图句柄,可以简单的使用 g = tf.get_default_graph() 获取。

从图中获取节点信息

有的时候我们对于模型中的节点并不够了解,此时我们可以通过图句柄来查询图的构造:

g = tf.get_default_graph()
print(g.as_graph_def().node)

这个操作将返回图的构造结构。从这里,对比前面的代码,我们也可以了解到:graph_def 实际就是图的结构信息存储形式,我们可以将之还原为图(二进制模型加载代码中展示了),也可以从图中将之提取出来(本部分代码)。

『TensorFlow』模型保存和载入方法汇总的更多相关文章

  1. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  2. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  3. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

  4. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  5. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  6. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  7. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  8. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  9. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

随机推荐

  1. (3.1)mysql备份与恢复之mysqldump

    目录: 1.单实例联系 1.1.备份单个数据库联系多种参数使用 [1]mysqldump命令备份演示 [2]查看备份文件 [3]用备份文件还原 1.2.mysqldump 各类参数释义 [1]--de ...

  2. FastDFS的单点部署

    1  安装libfastcommon 注意:在Centos7下和在Ubuntu下安装FastDFS是不同的,在Ubuntu上安装FastDFS需要安装libevent,而外Centos上安装FastD ...

  3. 2018-2019-1 20189203《Linux内核原理与分析》第四周作业

    第一部分 课本学习 内核版本号:Linux内核自2013年12月起,就以A.B.C.D的方式命名.A和B变得无关紧要,C是内核的真实版本.每一个版本的变化都会带来新的特性,如内部API的变化等,改动的 ...

  4. python_高级特征

    切片 Slice  : 取一个tuple的前三个元素,传统做法如下 : for i in range(3): dataList.append(testTuple[i]) if i == 2: prin ...

  5. poj 1164 深度优先搜索模板题

    #include<iostream> //用栈进行的解决: #include<cstdio> #include<algorithm> #include<cst ...

  6. word之选中文本

    在word和notepad中: 特别是在文件很大,如果用鼠标下滑的话,不知道会滑多久呢, 快捷键+鼠标点击截至处

  7. CentOS 7 搭建Jumpserver跳板机(堡垒机)

    跳板机概述: 跳板机就是一台服务器,开发或运维人员在维护过程中首先要统一登录到这台服务器,然后再登录到目标设备进行维护和操作 跳板机缺点:没有实现对运维人员操作行为的控制和审计,使用跳板机的过程中还是 ...

  8. 20165215 2017-2018-2《Java程序设计》课程总结

    20165215 2017-2018-2<Java程序设计>课程总结 一.每周作业链接汇总 预备作业1:我期望的师生关系:令我记忆深刻的老师,期望的师生关系,本学期的学习规划. 预备作业二 ...

  9. Axis2开发WebService客户端 的3种方式

    Axis2开发WebService客户端 的3种方式 在dos命令下   wsdl2java        -uri    wsdl的地址(网络上或者本地)   -p  com.whir.ezoffi ...

  10. Oarcle 入门之where关键字

    --where 关键字 --作用:过滤行 *将需要的行用where引导出来 用法: 1.判断符号:=,!=,<>,<,>,<=,>=,any,some,all; 例 ...