TF Boys (TensorFlow Boys ) 养成记(三): TensorFlow 变量共享
上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题。
为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,生成图像和真实图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制。
变量共享主要涉及到两个函数: tf.get_variable(<name>, <shape>, <initializer>) 和 tf.variable_scope(<scope_name>)。
先来看第一个函数: tf.get_variable。
tf.get_variable 和tf.Variable不同的一点是,前者拥有一个变量检查机制,会检测已经存在的变量是否设置为共享变量,如果已经存在的变量没有设置为共享变量,TensorFlow 运行到第二个拥有相同名字的变量的时候,就会报错。
例如如下代码:
def my_image_filter(input_images):
conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv1_weights")
conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
conv1 = tf.nn.conv2d(input_images, conv1_weights,
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv1 + conv1_biases)
有两个变量(Variables)conv1_weighs, conv1_biases和一个操作(Op)conv1,如果你直接调用两次,不会出什么问题,但是会生成两套变量;
# First call creates one set of 2 variables.
result1 = my_image_filter(image1)
# Another set of 2 variables is created in the second call.
result2 = my_image_filter(image2)
如果把 tf.Variable 改成 tf.get_variable,直接调用两次,就会出问题了:
result1 = my_image_filter(image1)
result2 = my_image_filter(image2)
# Raises ValueError(... conv1/weights already exists ...)
为了解决这个问题,TensorFlow 又提出了 tf.variable_scope 函数:它的主要作用是,在一个作用域 scope 内共享一些变量,可以有如下几种用法:
1)
with tf.variable_scope("image_filters") as scope:
result1 = my_image_filter(image1)
scope.reuse_variables() # or
#tf.get_variable_scope().reuse_variables()
result2 = my_image_filter(image2)
需要注意的是:最好不要设置 reuse 标识为 False,只在需要的时候设置 reuse 标识为 True。
2)
with tf.variable_scope("image_filters1") as scope1:
result1 = my_image_filter(image1)
with tf.variable_scope(scope1, reuse = True)
result2 = my_image_filter(image2)
通常情况下,tf.variable_scope 和 tf.name_scope 配合,能画出非常漂亮的流程图,但是他们两个之间又有着细微的差别,那就是 name_scope 只能管住操作 Ops 的名字,而管不住变量 Variables 的名字,看下例:
with tf.variable_scope("foo"):
with tf.name_scope("bar"):
v = tf.get_variable("v", [1])
x = 1.0 + v
assert v.name == "foo/v:0"
assert x.op.name == "foo/bar/add"
参考资料:
1. https://www.tensorflow.org/how_tos/variable_scope/
TF Boys (TensorFlow Boys ) 养成记(三): TensorFlow 变量共享的更多相关文章
- TF Boys (TensorFlow Boys ) 养成记(三)
上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题. 为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生 ...
- TF Boys (TensorFlow Boys ) 养成记(一)
本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于 ...
- TF Boys (TensorFlow Boys ) 养成记(一):TensorFlow 基本操作
本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于 ...
- TF Boys (TensorFlow Boys ) 养成记(五)
有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...
- TF Boys (TensorFlow Boys ) 养成记(五): CIFAR10 Model 和 TensorFlow 的四种交叉熵介绍
有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...
- 记录:tf.saved_model 模块的简单使用(TensorFlow 模型存储与恢复)
虽然说 TensorFlow 2.0 即将问世,但是有一些模块的内容却是不大变化的.其中就有 tf.saved_model 模块,主要用于模型的存储和恢复.为了防止学习记录文件丢失或者蠢笨的脑子直接遗 ...
- TensorFlow学习笔记(三)MNIST数字识别问题
一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...
- Tensorflow 载入数据的三种方式
Tensorflow 数据读取有三种方式: Preloaded data: 预加载数据 Feeding: Python产生数据,再把数据喂给后端. Reading from file: 从文件中直接读 ...
- TensorFlow 初级教程(三)
TensorFlow基本操作 import os import tensorflow as tf os.environ[' # 使用TensorFlow输出Hello # 创建一个常量操作( Cons ...
随机推荐
- java实现链队列
java实现链队列的类代码: package linkqueue; public class LinkQueue { class Element { Object elem; Element next ...
- UNIX网络编程读书笔记:地址操纵函数
地址格式转换函数:它们在ASCII字符串(人们比较喜欢用的格式)与网络字节序的二进制值(此值存于套接口地址结构中)间转换地址. 1.inet_aton.inet_addr.inet_ntoa inet ...
- 算法笔记_028:字符串转换成整数(Java)
1 问题描述 输入一个由数字组成的字符串,请把它转换成整数并输出.例如,输入字符串“123”,输出整数123. 请写出一个函数实现该功能,不能使用库函数. 2 解决方案 解答本问题的基本思路:从左至右 ...
- Java线程-volatile不能保证原子性
下面是一共通过volatile实现原子性的例子: 通过建立100个线程,计算number这个变量最后的结果. package com.Sychronized; public class Volatil ...
- python selenium 自动化测试环境安装
注意:2.7和3.0版本的语法有些不一样 安装自动化测试软件 selenium(地址http://www.seleniumhq.org/download/) 安装步骤: 1.安装pythone运行环境 ...
- mysql GROUP_CONCAT 函数 将相同的键的多个单元格合并到一个单元格
mysql GROUP_CONCAT 函数 将相同的键的多个单元格合并到一个单元格 MemberID MemberName FruitName -------------- ------------- ...
- Android 打造完美的侧滑菜单/侧滑View控件
概述 Android 打造完美的侧滑菜单/侧滑View控件,完全自定义实现,支持左右两个方向弹出,代码高度简洁流畅,兼容性高,控件实用方便. 详细 代码下载:http://www.demodashi. ...
- spring各种邮件发送
参考地址一 参考地址二 参考地址三 参考地址四 Spring邮件抽象层的主要包为org.springframework.mail.它包括了发送电子邮件的主要接口MailSender,和值对象Simpl ...
- Python 3.x 连接 pymysql 数据库
首先,需要安装库: 使用 pycharm IDE,如PyCharm,可以使用 project python 安装第三方模块. [File] >> [settings] >> [ ...
- Android用http协议上传文件
http协议上传文件一般最大是2M,比较适合上传小于两M的文件 [代码] [Java]代码 001import java.io.File; 002import java.io.FileInp ...