模型的存储与加载

TF的API提供了两种方式来存储和加载模型:

1.生成检查点文件,扩展名.ckpt,通过在tf.train.Saver()对象上调用Saver.save()生成。包含权重和其他在程序中定义的变量,不包含图结构。

2.生成图协议文件,扩展名.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

模型的存储与加载

https://github.com/nlintz/TensorFlow-Tutorials/blob/master/10_save_restore_net.py)

加载数据及定义模型

#加载数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10]) #初始化权重参数
w_h = init_weights([784, 625])
w_h2 = init_weights([625, 625])
w_o = init_weights([625, 10]) #定义权重函数
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01)) #定义模型
def model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden): # this network is the same as the previous one except with an extra hidden layer + dropout
#第一个全连接层
X = tf.nn.dropout(X, p_keep_input)
h = tf.nn.relu(tf.matmul(X, w_h)) h = tf.nn.dropout(h, p_keep_hidden)
#第一个全连接层
h2 = tf.nn.relu(tf.matmul(h, w_h2)) h2 = tf.nn.dropout(h2, p_keep_hidden) return tf.matmul(h2, w_o)#输出预测值

生成网络模型,得到预测值,代码如下:

p_keep_input = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden)

定义损失函数:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)

训练模型及存储模型

首先定义一个存储路径:

ckpt_dir = "./ckpt_dir"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)

定义一个计数器,为训练轮数计数:

global_step = tf.Variable(0, name='global_step', trainable=False)

当定义完所有变量后,调用tf.train.Saver()来保存和提取变量:

# Call this after declaring all tf.Variables.
saver = tf.train.Saver() # This variable won't be stored, since it is declared after tf.train.Saver()
non_storable_variable = tf.Variable(777)

训练模型并存储

with tf.Session() as sess:
# you need to initialize all variables
tf.global_variables_initializer().run() start = global_step.eval() # get last global_step
print("Start from:", start) for i in range(start, 100):
for start, end in zip(range(0, len(trX), 128), range(128, len(trX)+1, 128)):
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
p_keep_input: 0.8, p_keep_hidden: 0.5}) global_step.assign(i).eval() # set and update(eval) global_step with index, i
saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)

加载模型

如果有训练好的模型变量文件,可以用saver.restore()来进行模型加载:

# Launch the graph in a session
with tf.Session() as sess:
# you need to initialize all variables
tf.global_variables_initializer().run() ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables

图的存储与加载

当仅保存图模型时,才将图写入二进制文件中:

v=tf.Variable(0,name='my_variable')
sess=tf.Session()
tf.train.write_graph(sess.gaph_def,'/tmp/tfmodel','train.pbtxt')

当读取时,又从协议文件中读取出来:

with tf.Session() as_sess:
with gfile.FastGFile("/tem/tfmodel/train.pbtxt",'rb') as f:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
_sess.grap.as_default()
tf.import_graph_def(graph_def,name='tfgraph')

队列和线程

队列

在TF中有两种队列,即FIFOQueue和RandomShuffleQueue.

FIFOQueue:创建一个先入先出队列

RandomShuffleQueue:创建一个随机队列

队列管理器

QueueRunner

线程和协调器

使用协调器(Coordinator)来管理线程。

加载数据

TF给出了3种方法:

1.预加载数据:在TensorFlow图中定义常量或变量来保存所有数据

2.填充数据feeding:Python产生数据,再把数据填充后端

3.从文件中读取数据:让队列管理器从文件中读取数据

预加载数据

缺点:当训练数据较大时,很消耗内存。

x1=tf.constant([2,3,4])
x2=tf.constant([2,1,4])
y=tf.add(x1,x2)

填充数据

使用sess.run()中的feed_dict参数,将Python产生的数据填充给后端。

#设计图
a1=tf.placeholder(tf.int16)
a2=tf.placeholder(tf.int16)
b=tf.add(x1,x2) #用Python产生数据
li1=[2,3,4]
li2=[2,1,4] #打开一个会话,将数据填充给后端
with tf.Session() as sess:
print(sess.run(b,feed_dict={a1:li1,a2:li2})

https://www.tensorflow.org/guide/datasets#preloaded_data)

填充的方式也有数据量大、消耗内存等缺点。这时最好用第三种,从文件读取。

填充数据

从文件中读取数据分为两个步骤:

1.把样本数据写入TFRecords二进制文件

2.再从队列中读取

TF基础4的更多相关文章

  1. TF基础3

    批标准化 批标准化(batch normalization,BN)是为了克服神经网络层数加深导致难以训练而诞生的.深度神经网络随着深度加深,收敛会越来越慢,会导致梯度弥散问题(vanishing gr ...

  2. TF基础2

    1.常用API 1.图,操作和张量 tf.Graph,tf.Operation,tf.Tensor 2.可视化 TensorBoard 3.变量作用域 在TF中有两个作用域(scope),一个是nam ...

  3. ROS tf基础使用知识

    博客参考:https://www.ncnynl.com/archives/201702/1306.html ROS与C++入门教程-tf-坐标变换 说明: 介绍在c++实现TF的坐标变换 概念: Co ...

  4. TF基础5

    卷积神经网络CNN 卷积神经网络的权值共享的网络结构显著降低了模型的复杂度,减少了权值的数量. 神经网络的基本组成包括输入层.隐藏层和输出层. 卷积神经网络的特点在于隐藏层分为卷积层和池化层. pad ...

  5. ROS探索总结(十八)——重读tf

    在之前的博客中,有讲解tf的相关内容,本篇博客重新整理了tf的介绍和学习内容,对tf的认识会更加系统. 1 tf简介 1.1 什么是tf tf是一个让用户随时间跟踪多个参考系的功能包,它使用一种树型数 ...

  6. [TF] Architecture - Computational Graphs

    阅读笔记: 仅希望对底层有一定必要的感性认识,包括一些基本核心概念. Here只关注Graph相关,因为对编程有益. TF – Kernels模块部分参见:https://mp.weixin.qq.c ...

  7. tf

    第2章 Tensorflow keras实战 2-0 写在课程之前 课程代码的Tensorflow版本 大部分代码是tensorflow2.0的 课程以tf.kerasAPI为主,因而部分代码可以在t ...

  8. Variables多种表达

    Variables:TF基础数据之一,常用于变量的训练...重要性刚学TF就知道了 1.tf.Variable() tf.Variable(initial_value=None, trainable= ...

  9. [Tensorflow] Cookbook - The Tensorflow Way

    本章介绍tf基础知识,主要包括cookbook的第一.二章节. 方针:先会用,后定制 Ref: TensorFlow 如何入门? Ref: 如何高效的学习 TensorFlow 代码? 顺便推荐该领域 ...

随机推荐

  1. MongoDB_聚合

    MongoDB提供以下聚合工具来对数据进行操作:聚合框架.MapReduce以及几个简单聚合命令:count.distinct.group 聚合框架:可以使用多个构件创建一个管道,上一个构件的结果传给 ...

  2. matlab学习-使用自带的函数

    >> %定义矩阵求最大值>> a=[1 7 3;6 2 9];>> A=max(a);>> a a = 1 7 3 6 2 9 >> A A ...

  3. Nginx+Php-fpm运行原理

    一.代理与反向代理 现实生活中的例子 1.正向代理:访问google.com 如上图,因为google被墙,我们需要vpnFQ才能访问google.com. vpn对于“我们”来说,是可以感知到的(我 ...

  4. Python爬虫4------图片爬虫

    import urllib.request import re keyname="短裙" key=urllib.request.quote(keyname) headers=(&q ...

  5. C++进阶 STL(2) 第二天 一元/二元函数对象、一元/二元谓词、stack容器、queue容器、list容器(双向链表)、set容器、对组、map容器

    01 上次课程回顾 昨天讲了三个容器 string  string是对char*进行的封装 vector 单口容器 动态数组 deque(双端队列) 函数对象/谓词: 一元函数对象: for_each ...

  6. 执行目标文件引发的问题:syntax error: word unexpected (expe...

    今天不小心把一个目标文件当成了可执行文件放到开发板上进行执行,结果出现了这样一个问题:./hello_qt: line 1: syntax error: word unexpected (expect ...

  7. 关于高校表白APP的用户模板和用户场景

      用户模板一: 用户名 小明 性别,年龄 男,20岁 用户状况 单身,在校大学生 生活爱好 喜欢打篮球,唱歌 典型场景 希望找到一个心仪的可以走到最后的姑娘 典型描述 交友 用户比例 ? 用户场景一 ...

  8. Android内存管理-SoftReference的使用

    本文介绍对象的强.软.弱和虚引用的概念.应用及其在UML中的表示. 1.对象的强.软.弱和虚引用 在JDK 1.2以前的版本中,若一个对象不被任何变量引用,那么程序就无法再使用这个对象.也就是说,只有 ...

  9. 小程序中 wx.navigateTo 页面跳转没有反应?

    页面js文件中加入 show: function () {wx.navigateTo({url: ‘/pages/show/show’})} 这个函数 目的在于要做跳转到新的页面,但是你可能会遇到一个 ...

  10. 【手势交互】6. 微动VID

    中国 天津 http://www.sharpnow.com/ 微动VID是天津锋时互动科技有限公司开发的中国Leap Motion. 它能够识别并跟踪用户手部的姿态.包含:指尖和掌心的三维空间位置:手 ...