Tensorflow中的变量
从初识tf开始,变量这个名词就一直都很重要,因为深度模型往往所要获得的就是通过参数和函数对某一或某些具体事物的抽象表达。而那些未知的数据需要通过学习而获得,在学习的过程中它们不断变化着,最终收敛达到较好的表达能力,因此它们无疑是变量。
正如三位大牛所言:深度学习是一种多层表示学习方法,用简单的非线性模块构建而成,这些模块将上一层表示转化成更高层、更抽象的表示。
原文如下: Deep-learning methods are representation-learning methods with multiple levels of representation, obtained by composing simple but non-linear modules that each transform the representation at one level (starting with the raw input) into a representation at a higher, slightly more abstract level.
必读文献之一:Deep Learning
当训练模型时,用变量来存储和更新参数。变量包含张量 (Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析是被加载。
通过之前的学习,可以例举出以下tf的函数:
var = tf.get_variable(name, shape, initializer=initializer)
global_step = tf.Variable(0, trainable=False)
init = tf.initialize_all_variables()#高版本tf已经舍弃该函数,改用global_variables_initializer()
saver = tf.train.Saver(tf.global_variables())
initial = tf.constant(0.1, shape=shape)
initial = tf.truncated_normal(shape, stddev=0.1)
tf.global_variables_initializer()
上述函数都和tf的参数有关,主要包含在以下两类中:
从变量存在的整个过程来看上述两类:变量的创建、初始化、更新、保存和加载。
- 创建
当创建一个变量时,将一个张量
作为初始值传入构造函数Variable()
。tf提供了一系列操作符来初始化张量,初始值是常量或是随机值。注意,所有这些操作符都需要你指定张量的shape。变量的shape通常是固定的,但TensorFlow提供了高级的机制来重新调整其行列数。
可以创建以下类型的变量:常数、序列、随机数。例如:
#-*-coding:utf-8-*-
#创建常数变量的例子
import tensorflow as tf
#常数constant
tensor=tf.constant([[1,3,5],[8,0,7]])
#创建tensor值为0的变量
x = tf.zeros([3,4])
#创建tensor值为1的变量
x1 = tf.ones([3,4])
#创建shape和tensor一样的但是值全为0的变量
y = tf.zeros_like(tensor)
#创建shape和tensor一样的但是值全为1的变量
y1 = tf.ones_like(tensor)
#用8填充shape为2*3的tensor变量
z = tf.fill([2,3],8)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print (sess.run(x))
print (sess.run(y))
print (sess.run(tensor))
print (sess.run(x1))
print (sess.run(y1))
print (sess.run(z))
#-*-coding:utf-8-*-
#创建数字序列变量的例子
import tensorflow as tf x=tf.linspace(10.0, 15.0, 3, name="linspace")
y=tf.lin_space(10.0, 15.0, 3)
w=tf.range(8.0, 13.0, 2.0)
z=tf.range(3, -3, -2)
sess = tf.Session()
sess.run(tf.global_variables_initializer()) print (sess.run(x))
print (sess.run(y))
print (sess.run(w))
print (sess.run(z))
随机常量的创建详见tensorflow随机张量创建
#创建随机变量的例子
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
name="weights")
- 初始化
变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。使用tf.global_variables_initializer()添加一个操作对变量做初始化。例如:
# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init = tf.global_variables_initializer() # Later, when launching the model
with tf.Session() as sess:
# Run the init operation.
sess.run(init)
...
# Use the model
...
有时候会需要用另一个变量的初始化值给当前变量初始化。由于tf.global_variables_initializer()是并行地初始化所有变量,所以用其它变量的值初始化一个新的变量时,使用其它变量的initialized_value()
属性。你可以直接把已初始化的值作为新变量的初始值,或者把它当做tensor计算得到一个值赋予新变量。例如:
# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")
assign()函数也有初始化的功能,详见assign()函数
另外,这里还应该说明的是还有三种读取数据的方法:Feeding、文件中读取、加载预训练数据,它们都属于给变量初始化的方式。为了不至于引起混淆,必须说明的是常量也是变量,而三种读取数据的方法,都是读取常量的方法,但依然是初始化的一种常见方式。详见Tensorflow数据读取的方式
- 更新
虽然assign()函数有对变量进行更新的作用,但是此处探讨的更新却不是如此简单。而事实上,我们不需要做什么具体的事情,因为tf是自动求导求梯度,根据代价函数自动更新参数的。这是全局参数的更新,也是tf学习的机制自动确定的。那tf如何知道哪个究竟是变量,哪个究竟又是常量呢?很简单,tf.variable()里面有个布尔型的参数trainable,表示这个参数是不是需要学习的变量,而它默认为true,因此很容易被忽略,就这样tf图会把它加入到GraphKeys.TRAINABLE_VARIABLES,从而对其进行更新。
- 保存
对于训练的变量,成功的话,都是有意义的,需要将其保存在文件里,方便以后的测试和再训练,这就是weights文件,是必不可少的。
在cifar10项目中当然也有保存这些变量,例如:
# Create a saver.
saver = tf.train.Saver(tf.global_variables())
......
# Save the model checkpoint periodically.
if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)
Saver类把变量存储在二进制文件checkpoint里,主要包含从变量名到tensor值的映射关系。
- 加载
加载变量和保存变量是正反的过程,保存变量是要把模型里的变量信息保存到weights文件里,而加载变量就是要把这些有意义的变量值从weights文件加载到模型里。
同理在cifar10项目中测试训练的模型时加载了上述保存的变量,例如:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
# Assuming model_checkpoint_path looks something like:
# /my-favorite-path/cifar10_train/model.ckpt-0,
# extract global_step from it.
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
else:
print('No checkpoint file found')
return
如果想选择和加载某一部分变量,则可以通过变量名索引,例如:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...
这里my_v2就是新的变量名,而v2就是它的值。
Tensorflow中的变量的更多相关文章
- TensorFlow中的变量和常量
1.TensorFlow中的变量和常量介绍 TensorFlow中的变量: import tensorflow as tf state = tf.Variable(0,name='counter') ...
- 2、Tensorflow中的变量
2.Tensorflow中的变量注意:tf中使用 变量必须先初始化下面是一个使用变量的TF代码(含注释): # __author__ = "WSX" import tensorfl ...
- 83、Tensorflow中的变量管理
''' Created on Apr 21, 2017 @author: P0079482 ''' #如何通过tf.variable_scope函数来控制tf.ger_variable函数获取已经创建 ...
- 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量
1.tf.Variable([[1, 2]]) # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...
- TensorFlow中的变量命名以及命名空间.
What: 在Tensorflow中, 为了区别不同的变量(例如TensorBoard显示中), 会需要命名空间对不同的变量进行命名. 其中常用的两个函数为: tf.variable_scope, t ...
- tensorflow中使用变量作用域及tf.variable(),tf,getvariable()与tf.variable_scope()的用法
一 .tf.variable() 在模型中每次调用都会重建变量,使其存储相同变量而消耗内存,如: def repeat_value(): weight=tf.variable(tf.random_no ...
- tensorflow中常量(constant)、变量(Variable)、占位符(placeholder)和张量类型转换reshape()
常量 constant tf.constant()函数定义: def constant(value, dtype=None, shape=None, name="Const", v ...
- tensorflow中张量_常量_变量_占位符
1.tensor 在tensorflow中,数据是被封装在tensor对象中的.tensor是张量的意思,即包含从0到任意维度的张量.常数是0维度的张量,向量是1维度的张量,矩阵是二维度的张量,以及还 ...
- tensorflow中slim模块api介绍
tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35 http://blog.csdn.net/guvcolie/article/details/77686 ...
随机推荐
- CentOS7安装OpenStack(Rocky版)-01.控制节点的系统环境准备
分享一下Rocky版本的OpenStack安装管理经验: OpenStack每半年左右更新一版,目前是版本是201808月发布的版本-R版(Rocky),目前版本安装方法优化较好,不过依然是比较复杂 ...
- Bootstrap 样式设计 栅格系统
.col-xs- 超小屏幕 手机 (<768px) .col-sm- 小屏幕 平板 (≥768px) .col-md- 中等屏幕 桌面显示器 (≥992px) .col-lg- 大屏幕 大桌面显 ...
- 升级framework4.0后form认证票据失效的问题
好久没来了,密码都差点忘了,顺便记录下今天配置环境碰到的小问题 网站使用的form authentication做SSO登录,登录域名使用的framework20配置环境 一个栏目升级为4.0环境后, ...
- ContentProvider示例
http://hi.baidu.com/pekdou/item/b2a070c37552af210831c678 首先,我自己是各初学者,网上一些关于ContentProvider的例子也不少,我自己 ...
- CentOS 修改时区的方法
study from https://blog.csdn.net/skh2015java/article/details/85007624 第一种 tzselect 输入命令直接选择即可 第二种,直接 ...
- dotTrace 每行执行时间和执行次数
如果代码中出现效率问题,使用dotTrace来跟踪分析代码的效率问题还是很方便的.使用dotTrace不但可以看到每一个方法被调用的次数和总时间,而且可以引入源代码,查看源代码中每一行执行的次数和时间 ...
- 使用navicat 链接数据库时乱码
在建立数据库链接时设置 高级->编码->uft-8 其他版本使用下面方法
- Python开发【第七章】:面向对象进阶
1.静态方法 通过@staticmethod装饰器即可把其装饰的方法变为一个静态方法,什么是静态方法呢?其实不难理解,普通的方法,可以在实例化后直接调用,并且在方法里可以通过self.调用实例变量或类 ...
- php学习目录
前面的话 前端工程师为什么要学习php?是因为招聘要求吗?这只是一方面 一开始,我对学习php是抵触的,毕竟javascript已经够自己喝一壶的了,再去学习php,可能让自己喝醉.但是,在学习jav ...
- Goodbye My Old Days
几天前的CTT的胸牌上印着熟悉的初中学校的名字,回想起自己早已废弃的博客,不禁感慨万分.如你所见,一位名叫supy的菜鸡OIer曾经小心翼翼地写下一篇篇文章来装点这个地方,时间是初二的ZJOID1直到 ...