TensorFlow 常用的函数
TensorFlow 中维护的集合列表
在一个计算图中,可以通过集合(collection
)来管理不同类别的资源。比如通过 tf.add_to_collection
函数可以将资源加入一个或多个集合中,然后通过 tf.get_collection
获取一个集合里面的所有资源(如张量,变量,或者运行TensorFlow程序所需的队列资源等等)。比如,通过 tf.add_n(tf.get_collection('losses'))
获得总损失。
集合名称 | 集合内容 | 使用场景 |
---|---|---|
tf.GraphKeys.VARIABLES |
所有变量 | 持久化 TensorFlow 模型 |
tf.GraphKeys.TRAINABLE_VARIABLES |
可学习的变量(一般指神经网络中的参数) | 模型训练、生成模型可视化内容 |
tf.GraphKeys.SUMMARIES |
日志生成相关的张量 | TensorFlow 计算可视化 |
tf.GraphKeys.QUEUE_RUNNERS |
处理输入的 QueueRunner | 输入处理 |
tf.GraphKeys.MOVING_AVERAGE_VARIABLES |
所有计算了滑动平均值的变量 | 计算变量的滑动平均值 |
- TensorFlow中的所有变量都会被自动加入
tf.GraphKeys.VARIABLES
集合中,通过tf.global_variables()
函数可以拿到当前计算图上的所有变量。拿到计算图上的所有变量有助于持久化整个计算图的运行状态。 - 当构建机器学习模型时,比如神经网络,可以通过变量声明函数中的
trainable
参数来区分需要优化的参数(比如神经网络的参数)和其他参数(比如迭代的轮数,即超参数),若trainable = True
,则此变量会被加入tf.GraphKeys.TRAINABLE_VARIABLES
集合。然后通过tf.trainable_variables
函数便可得到所有需要优化的参数。TensorFlow中提供的优化算法会将tf.GraphKeys.TRAINABLE_VARIABLES
集合中的变量作为 默认的优化对象。
示例
tf.get_collection
的第一个参数是集合的名字,第二个参数是要加入集合的内容:
def get_weight(shape, lambda1):
# 获取一层神经网络边上的权重
var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
# 将这个权重的 L2 正则化损失加入名称为 'losses' 的集合中
tf.add_to_collection('losses',
tf.contrib.layers.l2_regularizer(lambda1)(var))
return var
变量初始化函数
神经网络中的参数是通过 TensorFlow 中的变量来组织、保存和使用的。TensorFlow 中提供了两种变量机制:tf.Variable
和 tf.get_variable
.
- 变量的类型是不可以改变的。
- 变量的维度一般是不能改变的,除非设置参数
validate_shape = False
(很少去改变它)
随机数初始化函数
函数名 | 随机数分布 | 主要参数 |
---|---|---|
tf.random_normal |
正态分布 | 平均值、标准差、取值类型 |
tf.truncated_normal |
满足正态分布的随机值,但若随机值偏离平均值超过2个标准差,则这个数会被重新随机 | 平均值、标准差、取值类型 |
tf.random_uniform |
平均分布 | 最大、最小值、取值类型 |
tf.random_gamma |
Gramma分布 | 形状参数alpha、尺度参数beta、取值类型 |
常量初始化函数
函数名 | 功能 | 示例 |
---|---|---|
tf.zeros |
产生全0的数组 | tf.zeros([2, 3],tf.int32) |
tf.ones |
产生全1的数组 | tf.ones([2, 3],tf.int32) |
tf.fill |
产生一个全部为给定数组的数组 | tf.fill([2,3], 9) |
tf.constant |
产生一个给定值的常量 | tf.constant([2,3,4]) |
tf.get_variable
变量初始化函数
初始化函数 | 功能 | 主要参数 |
---|---|---|
tf.constant_initializer |
将变量初始化为给定常数 | 常数的取值 |
tf.random_normal_initializer |
将变量初始化为满足正态分布的随机值 | 正态分布的均值和标准差 |
tf.truncated_normal_initializer |
将变量初始化为满足正态分布的随机值,但若随机值偏离平均值超过2个标准差,则这个数会被重新随机 | 正态分布的均值和标准差 |
tf.random_uniform_initializer |
将变量初始化为满足平均分布的随机值 | 最大、最小值 |
tf.uniform_unit_scaling_initializer |
将变量初始化为满足平均分布但不影响输出数量级的随机值 | factor(产生随机值时乘以的系数) |
tf.zeros_initializer |
将变量初始化为全0 | 变量维度 |
tf.ones_initializer |
将变量初始化为全1 | 变量维度 |
当 tf.get_variable
用于创建变量时,它和 tf.Variable
的功能是基本等价的。而 tf.get_variable
与 tf.Variable
的最大的区别在于指定变量名称的参数。
- 对于
tf.Variable
函数,变量名称是一个可选参数,通过name='v'
的形式给出; - 对于
tf.get_variable
函数,变量名称是一个必填的参数。tf.get_variable
函数会根据这个名字去创建或者获取变量。
详细内容见 变量管理
其他
tf.clip_by_value
函数将张量限定在一定的范围内:
import tensorflow as tf
sess = tf.InteractiveSession()
v = tf.constant([[1., 2., 3.], [4., 5., 6.]])
tf.clip_by_value(v, 2.5, 4.5).eval() # 小于2.5的数值设为2.5,大于4.5的数值设为4.5
array([[2.5, 2.5, 3. ],
[4. , 4.5, 4.5]], dtype=float32)
tf.log
对张量所有元素进行对数运算
tf.log(v).eval()
array([[0. , 0.6931472, 1.0986123],
[1.3862944, 1.609438 , 1.7917595]], dtype=float32)
tf.greater
,比较这两个张量中的每一个元素,并返回比较结果
- 输入是两个张量
- 当输入维度不一致时会进行广播(broadcasting)
v1 = tf.constant([1., 2., 3., 4.])
v2 = tf.constant([4., 3., 2., 1.])
f = tf.greater(v1, v2)
f.eval()
array([False, False, True, True])
tf.where
比较函数
函数有三个参数:
- 第一个选择条件根据,当选择条件为
True
时,会选择第二个参数中的值,否则使用第三个参数中的值:
tf.where(f, v1, v2).eval()
array([4., 3., 3., 4.], dtype=float32)
指数衰减学习率
tf.train.exponential_decay
函数指数衰减学习率。
tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)
learning_rate
:事先设定的初始学习率decay_steps
: 衰减速度,staircase = True
时代表了完整的使用一遍训练数据所需要的迭代轮数(= 总训练样本数/每个batch中的训练样本数)decay_rate
: 衰减系数staircase
: 默认为False
,此时学习率随迭代轮数的变化是连续的(指数函数);为True
时,global_step/decay_steps
会转化为整数,此时学习率便是阶梯函数
示例:
TRAINING_STEPS = 100
global_step = tf.Variable(0)
LEARNING_RATE = tf.train.exponential_decay(
0.1, global_step, 1, 0.96, staircase=True)
x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
y = tf.square(x)
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(
y, global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(TRAINING_STEPS):
sess.run(train_op)
if i % 10 == 0:
LEARNING_RATE_value = sess.run(LEARNING_RATE)
x_value = sess.run(x)
print("After %s iteration(s): x%s is %f, learning rate is %f." %
(i + 1, i + 1, x_value, LEARNING_RATE_value))
After 1 iteration(s): x1 is 4.000000, learning rate is 0.096000.
After 11 iteration(s): x11 is 0.690561, learning rate is 0.063824.
After 21 iteration(s): x21 is 0.222583, learning rate is 0.042432.
After 31 iteration(s): x31 is 0.106405, learning rate is 0.028210.
After 41 iteration(s): x41 is 0.065548, learning rate is 0.018755.
After 51 iteration(s): x51 is 0.047625, learning rate is 0.012469.
After 61 iteration(s): x61 is 0.038558, learning rate is 0.008290.
After 71 iteration(s): x71 is 0.033523, learning rate is 0.005511.
After 81 iteration(s): x81 is 0.030553, learning rate is 0.003664.
After 91 iteration(s): x91 is 0.028727, learning rate is 0.002436.
正则化
w = tf.constant([[1., -2.], [-3, 4]])
with tf.Session() as sess:
print(sess.run(tf.contrib.layers.l1_regularizer(.5)(w))) # 0.5 为正则化权重
print(sess.run(tf.contrib.layers.l2_regularizer(.5)(w)))
5.0
7.5
滑动平均模型
滑动平均模型会将每一轮迭代得到的模型综合起来,从而使得最终得到的模型在测试数据上更加健壮(robust)。
tf.train.ExponentialMovingAverage
需要提供一个衰减率(decay)来控制模型更新的速度。
ExponentialMovingAverage 对每一个变量会维护一个影子变量(shadow variable),这个影子变量的初始值就是相应变量的初始值,而每次运行变量更新时,影子变量的值会更新为:
\]
- shadow_variable 为影子变量,
- variable 为待更新变量
- decay 为衰减率,它越大模型越趋于稳定,在实际应用中decay一般会设置为接近 1 的数。
还可以使用 num_updates
参数来动态设置decay的大小:
\]
定义变量及滑动平均类
import tensorflow as tf
# 定义一个变量用来计算滑动平均,且其初始值为0,类型必须为实数
v1 = tf.Variable(0, dtype=tf.float32)
# step变量模拟神经网络中迭代的轮数,可用于动态控制衰减率
step = tf.Variable(0, trainable=False)
# 定义一个滑动平均的类(class)。初始化时给定了衰减率为0.99和控制衰减率的变量step
ema = tf.train.ExponentialMovingAverage(0.99, step)
# 定义一个更新变量滑动平均的操作。这里需要给定一个列表,每次执行这个操作时,此列表中的变量都会被更新。
maintain_averages_op = ema.apply([v1])
查看不同迭代中变量取值的变化。
with tf.Session() as sess:
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
# 通过ema.average(v1)获取滑动平均后的变量取值。在初始化之后变量v1的值和v1 的滑动平均均为0
print(sess.run([v1, ema.average(v1)]))
# 更新变量v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
# 更新一次v1的滑动平均值
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.555]
[10.0, 4.60945]
- 裁剪多余维度:
tf.squeeze
TensorFlow 常用的函数的更多相关文章
- TensorFlow常用的函数
TensorFlow中维护的集合列表 在一个计算图中,可以通过集合(collection)来管理不同类别的资源.比如通过 tf.add_to_collection 函数可以将资源加入一个 或多个集合中 ...
- Tensorflow常用的函数:tf.cast
1.tf.cast(x,dtype,name) 此函数的目的是为了将x数据,准换为dtype所表示的类型,例如tf.float32,tf.bool,tf.uint8等 example: import ...
- 深度学习TensorFlow常用函数
tensorflow常用函数 TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CPU 还是 GPU, Tensor ...
- TensorFlow常用Python扩展包
TensorFlow常用Python扩展包 TensorFlow 能够实现大部分神经网络的功能.但是,这还是不够的.对于预处理任务.序列化甚至绘图任务,还需要更多的 Python 包. 下面列出了一些 ...
- oracle(sql)基础篇系列(一)——基础select语句、常用sql函数、组函数、分组函数
花点时间整理下sql基础,温故而知新.文章的demo来自oracle自带的dept,emp,salgrade三张表.解锁scott用户,使用scott用户登录就可以看到自带的表. #使用ora ...
- php常用字符串函数小结
php内置了98个字符串函数(除了基于正则表达式的函数,正则表达式在此不在讨论范围),能够处理字符串中能遇到的每一个方面内容,本文对常用字符串函数进行简单的小结,主要包含以下8部分:1.确定字符串长度 ...
- php常用数组函数回顾一
数组对于程序开发来说是一个必不可少的工具,我根据网上的常用数组函数,结合个人的使用情况,进行数组系列的总结复习.里面当然不只是数组的基本用法,还有相似函数的不同用法的简单实例,力求用最简单的实例,记住 ...
- byte数据的常用操作函数[转发]
/// <summary> /// 本类提供了对byte数据的常用操作函数 /// </summary> public class ByteUtil { ','A','B',' ...
- WordPress主题模板层次和常用模板函数
首页: home.php index.php 文章页: single-{post_type}.php – 如果文章类型是videos(即视频),WordPress就会去查找single-videos. ...
随机推荐
- JavaScript学习 - 基础(七) - DOM event(事件)
DOM event(事件) 定义事件: // 定义事件: //方式一,直接在标签上定义事件 // 方式二: var a11 = document.getElementsByName('a11')[0] ...
- Java 线性表、栈、队列和优先队列
1.集合 2.迭代器 例子: 3.线性表 List接口继承自Collection接口,有两个具体的类ArrayList或者LinkedList来创建一个线性表 数组线性表ArrayList Linke ...
- ODPS
ODPS 功能之概述篇 原文 http://blog.aliyun.com/2962 主题 SQL 概述 ODPS是阿里云基于自有的云计算技术研发一套开放数据处理服务(Open Data Proce ...
- css 背景图片自适应元素大小
一.一种比较土的方法,<img>置于底层. 方法如下: CSS代码: HTML: <img src="背景图片路径" /> <span>字在背景 ...
- 无责任共享 Coursera、Udacity 等课程视频(转载)
转载链接:https://www.zybuluo.com/illuz/note/71868 B站计划:https://www.zybuluo.com/illuz/note/832995#cs基础课程
- 四、Logisitic Regssion练习(转载)
转载:http://www.cnblogs.com/tornadomeet/archive/2013/03/16/2963919.html 牛顿法:http://blog.csdn.net/xp215 ...
- freeRTOS中文实用教程3--中断管理之中断嵌套
1.前言 最新的 FreeRTOS 移植中允许中断嵌套.中断嵌套需要在 FreeRTOSConfig.h 中设置configKERNEL_INTERRUPT_PRIORITY 和configMAX_S ...
- Linux mmc framework2:基本组件之block
1.前言 本文主要block组件的主要流程,在介绍的过程中,将详细说明和block相关的流程,涉及到其它组件的详细流程再在相关文章中说明. 2.主要数据结构和API 2.1 struct mmc_ca ...
- 程序执行的过程分析--【sky原创】
程序执行的过程: 比如我们要执行3 + 2 程序计数器(PC) = 指令地址 指令寄存器(IR) = 正在执行的命令 累加器(AC) = 临时存储体 那么实际上执行了三条指令 每条指令 ...
- 阿里云服务器搭建FTP
操作系统:Windows Server 2008 R2企业版. 首先,创建一个用户组:ftpUsers,创建一个用户:ftpAdmin.并将ftpAdmin隶属于ftpUsers组 其次,需要安装ft ...