简介

TensorFlow变量是表示程序处理的共享持久状态的最佳方法。

我们使用tf.Variable类操作变量。tf.Variable表示可通过其运行操作来改变其值的张量。与tf.Tensor对象不同,tf.Variable存在于的单个session.run调用的上下文之外。

在TensorFlow内部,tf.Variable会存储持久性张量。具体op允许您读取和修改此张量的值。这些修改在多个tf.Session之间是可见的,因此对于一个tf.Variable,多个工作器可以看到相同的值。

创建变量

创建变量的最佳方式是调用tf.get_variable函数,此函数要求您指定变量的名称。其它副本将使用此名称访问同一变量,以及在对模型设置检查点和导出模型时指定此变量的值。tf.get_variable还允许你重复使用先前创建的同名变量,从而轻松定义重复利用层的模型。

要使用tf.get_variable创建变量,只需提供名称和形状即可:

  1. my_variable = tf.get_variable("my_variable", [1, 2, 3])

这将创建一个名为“my_variable”的变量,该变量是形状为 [1,2,3] 的三维张量。默认情况下,此变量具有 dtype=float32,其初始值将通过 tf.glorot_uniform_initializer 随机设置。

您可以选择为 tf.get_variable 指定dtype和初始化器。例如:

  1. my_int_variable = tf.get_variable("my_int_variable", [1, 2, 3], dtype=tf.int32,initializer=tf.zeros_initializer)

TensorFlow提供了许多方便的初始化器。或者,你也可以将tf.Variable初始化为tf.Tensor的值。例如:

  1. other_variable = tf.get_variable("other_variable", dtype=tf.int32,initializer=tf.constant([23, 42]))

请注意,当初始化器是tf.Tensor时,您不应该指定变量的形状,因为将使用初始化器张量的形状。

变量集合

由于TensorFlow程序的未连接部分可能需要创建变量,因此能有一种方式访问所有变量有时十分受用。为此,TensorFlow提供了集合,它们是张量或其它对象(如tf.Variable 实例)的命名列表。

默认情况下,每个tf.Variable都放置在以下两个集合中:

  • tf.GraphKeys.GLOBAL_VARIABLES - 可以在多台设备间共享的变量
  • tf.GraphKeys.TRAINABLE_VARIABLES - TensorFlow 将计算其梯度的变量。

如果您不希望变量可训练,可以将其添加到 tf.GraphKeys.LOCAL_VARIABLES 集合中。例如,以下代码段展示了如何将名为 my_local 的变量添加到此集合中:

  1. my_local = tf.get_variable("my_local", shape=(),collections=[tf.GraphKeys.LOCAL_VARIABLES])

或者,您可以指定 trainable=False(作为 tf.get_variable 的参数):

  1. my_non_trainable = tf.get_variable("my_non_trainable",
  2. shape=(),
  3. trainable=False)

您也可以使用自己的集合。集合名称可为任何字符串,且您无需显式创建集合。创建变量(或任何其他对象)后,要将其添加到集合中,请调用 tf.add_to_collection。例如,以下代码将名为 my_local 的现有变量添加到名为 my_collection_name 的集合中:

  1. tf.add_to_collection("my_collection_name", my_local)

要检索您放置在某个集合中的所有变量(或其他对象)的列表,您可以使用:

  1. tf.get_collection("my_collection_name")

设备放置方式

与任何其它TensorFlow指令一样,您可以将变量放置在特定设备上。例如,以下代码创建了名为v的变量并将其放置在第二个GPU上:

  1. with tf.device("/device:GPU:1"):
  2. v = tf.get_variable("v", [1])

在分布式设备中,将变量放置在正确设备上尤为重要。如果不小心将变量放在工作器而不是参数服务器上,可能会严重减慢训练速度,最坏情况下,可能会让每个工作器不断复制各个变量。为此,我们提供了 tf.train.replica_device_setter,它可以自动将变量放置在参数服务器中。例如:

  1. cluster_spec = {
  2. "ps": ["ps0:2222", "ps1:2222"],
  3. "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  4. with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
  5. v = tf.get_variable("v", shape=[20, 20]) # this variable is placed
  6. # in the parameter server
  7. # by the replica_device_setter

初始化变量

变量必须先初始化后才可使用。如果您在低级别的TensorFlow API中进行编程(即您在显式创建自己的图和会话),则必须明确初始化变量。tf.contrib.slim、tf.Estimator和Keras等大多数高级框架在训练模型时会自动为您初始化变量。

显式初始化在其它方面很有用。它允许您在从检查点重新加载模型时不用重新运行潜在资源消耗大的初始化器,并允许在分布式设备中共享随机初始化的变量时具有确定性。

要在训练前一次初始化所有可训练变量,请调用 tf.global_variables_initializer()。此函数会返回一个操作,负责初始化 tf.GraphKeys.GLOBAL_VARIABLES 集合中的所有变量。运行此操作会初始化所有变量。例如:

  1. session.run(tf.global_variables_initializer())
  2. # Now all variables are initialized.

如果您确实需要自行初始化变量,则可以运行变量的初始化器操作。例如:

  1. session.run(my_variable.initializer)

您可以查询哪些变量尚未初始化。例如,以下代码会打印出所有尚未初始化的变量名称:

  1. print(session.run(tf.report_uninitialized_variables()))

请注意,默认情况下,tf.global_variable_initializer 不会指定变量的初始化顺序。因此,如果变量的初始值取决于另一变量的值,那么很有可能会出现错误。任何时候,如果您在并非所有变量都已初始化的上下文中使用某个变量值(例如在初始化某个变量时使用另一个变量的值),最好使用variable.initialized_value(),而非variable:

  1. v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
  2. w = tf.get_variable("w", initializer=v.initialized_value() + 1)

使用变量

要在TensorFlow中使用tf.Variable的值,只需将其视为普通的tf.Tensor即可:

  1. v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
  2. w = v + 1 # w is a tf.Tensor which is computed based on the value of v.
  3. # Any time a variable is used in an expression it gets automatically
  4. # converted to a tf.Tensor representing its value.

要为变量赋值,请使用assign、assign_add方法以及tf.Variable类中的友元。例如,以下就是调用这些方法的方式:

  1. v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
  2. assignment = v.assign_add(1)
  3. tf.global_variables_initializer().run()
  4. sess.run(assignment) # or assignment.op.run(), or assignment.eval()

大多数TensorFlow优化器都有专门的op,会根据某种梯度下降算法有效的更新变量的值。请参阅 tf.train.Optimizer,了解如何使用优化器。

由于变量是可变的,因此及时了解任何时间点所使用的变量值版本有时十分有用。要在事件发生后强制重新读取变量的值,可以使用 tf.Variable.read_value。例如:

  1. v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
  2. assignment = v.assign_add(1)
  3. with tf.control_dependencies([assignment]):
  4. w = v.read_value() # w is guaranteed to reflect v's value after the
  5. # assign_add operation.

共享变量

TensorFlow支持两种共享变量的方式:

  • 显式传递tf.Variable对象
  • 将tf.Variable对象隐式封装在tf.variable_scope对象内

虽然显式传递变量的代码非常清晰,但有时编写在其实现中隐式使用变量的TensorFlow函数非常方便。tf.layers中的大多数功能层以及所有tf.metrics和部分其他库实用程序都使用这种方法。

变量作用域允许您在调用隐式创建和使用变量的函数时控制变量重用。作用域还允许您以分层和可理解的方式命名变量。

例如,假设我们编写一个函数创建一个卷积/relu层:

  1. def conv_relu(input, kernel_shape, bias_shape):
  2. # Create variable named "weights".
  3. weights = tf.get_variable("weights", kernel_shape,
  4. initializer=tf.random_normal_initializer())
  5. # Create variable named "biases".
  6. biases = tf.get_variable("biases", bias_shape,
  7. initializer=tf.constant_initializer(0.0))
  8. conv = tf.nn.conv2d(input, weights,
  9. strides=[1, 1, 1, 1], padding='SAME')
  10. return tf.nn.relu(conv + biases)

此函数使用短名称weights和biases,这有利于清晰区分二者。然而,在真实模型中,我们需要很多此类卷积层,而且重复调用此函数将不起作用:

  1. input1 = tf.random_normal([1,10,10,32])
  2. input2 = tf.random_normal([1,20,20,32])
  3. x = conv_relu(input1, kernel_shape=[5, 5, 32, 32], bias_shape=[32])
  4. x = conv_relu(x, kernel_shape=[5, 5, 32, 32], bias_shape = [32]) # This fails.

由于期望的操作不清除(创建新变量还是重新使用现有变量?),因此TensorFlow将会失败。不过,在不同作用域内调用conv_relu可表明我们想要创建新变量:

  1. def my_image_filter(input_images):
  2. with tf.variable_scope("conv1"):
  3. # Variables created here will be named "conv1/weights", "conv1/biases".
  4. relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
  5. with tf.variable_scope("conv2"):
  6. # Variables created here will be named "conv2/weights", "conv2/biases".
  7. return conv_relu(relu1, [5, 5, 32, 32], [32])

如果您想要共享变量,有两种方法可以选择。首先,您可以使用reuse=True 创建具有相同名称的作用域:

  1. with tf.variable_scope("model"):
  2. output1 = my_image_filter(input1)
  3. with tf.variable_scope("model", reuse=True):
  4. output2 = my_image_filter(input2)

您也可以调用 scope.reuse_variables()以触发重用:

  1. with tf.variable_scope("model") as scope:
  2. output1 = my_image_filter(input1)
  3. scope.reuse_variables()
  4. output2 = my_image_filter(input2)

由于依赖于作用域的确切字符串名称可能比较危险,因此也可以根据另一作用域初始化某个变量作用域:

  1. with tf.variable_scope("model") as scope:
  2. output1 = my_image_filter(input1)
  3. with tf.variable_scope(scope, reuse=True):
  4. output2 = my_image_filter(input2)

参考链接:https://tensorflow.google.cn/guide/variables

TensorFlow低阶API(三)—— 变量的更多相关文章

  1. TensorFlow低阶API(四)—— 图和会话

    简介 TensorFlow使用数据流图将计算表示为独立的指令之间的依赖关系.这可生成低级别的编程模型,在该模型中,您首先定义数据流图,然后创建TensorFlow会话,以便在一组本地和远程设备上运行图 ...

  2. TensorFlow低阶API(一)—— 简介

    简介 本文旨在知道您使用低级别TensorFlow API(TensorFlow Core)开始编程.您可以学习执行以下操作: 管理自己的TensorFlow程序(tf.Graph)和TensorFl ...

  3. TensorFlow低阶API(二)—— 张量

    简介 正如名字所示,TensorFlow这一框架定义和运行涉及张量的计算.张量是对矢量和矩阵向潜在的更高维度的泛化.TensorFlow在内部将张量表示为基本数据类型的n维数组. 在编写TensorF ...

  4. spark streaming kafka1.4.1中的低阶api createDirectStream使用总结

    转载:http://blog.csdn.net/ligt0610/article/details/47311771 由于目前每天需要从kafka中消费20亿条左右的消息,集群压力有点大,会导致job不 ...

  5. TebsorFlow低阶API(五)—— 保存和恢复

    简介 tf.train.Saver 类提供了保存和恢复模型的方法.通过 tf.saved_model.simple_save 函数可以轻松地保存适合投入使用的模型.Estimator会自动保存和恢复 ...

  6. Tensorflow object detection API 搭建物体识别模型(三)

    三.模型训练 1)错误一: 在桌面的目标检测文件夹中打开cmd,即在路径中输入cmd后按Enter键运行.在cmd中运行命令: python /your_path/models-master/rese ...

  7. Tensorflow object detection API 搭建物体识别模型(二)

    二.数据准备 1)下载图片 图片来源于ImageNet中的鲤鱼分类,下载地址:https://pan.baidu.com/s/1Ry0ywIXVInGxeHi3uu608g 提取码: wib3 在桌面 ...

  8. TensorFlow object detection API

    cloud执行:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pet ...

  9. Tensorflow object detection API ——环境搭建与测试

    1.开发环境搭建 ①.安装Anaconda 建议选择 Anaconda3-5.0.1 版本,已经集成大多数库,并将其作为默认python版本(3.6.3),配置好环境变量(Anaconda安装则已经配 ...

随机推荐

  1. 任务44:Identity MVC: EF + Identity实现

    使用VSCode开发 Razer的智能感知不好.所以这里切换为VS2017进行开发: 新建一个Data的文件夹来存放我们的DBContext.在Data文件夹下新建: ApplicationDbCon ...

  2. Golang项目的测试实践

    Golang项目的测试实践 最近有一个项目,链路涉及了4个服务.最核心的是一个配时服务.要如何对这个项目进行测试,保证输出质量,是最近思考和实践的重点.这篇就说下最近这个实践的过程总结. 测试金字塔 ...

  3. css里关于浏览器的前缀

    今天遇到一个比较坑爹的 -moz-box-sizing: border-box; box-sizing' border-box;   一下子有点懵逼,第一个什么鬼??一查,原来是火狐浏览器的前缀.应该 ...

  4. 704. Binary Search

    Given a sorted (in ascending order) integer array nums of n elements and a target value, write a fun ...

  5. 【Codeforces Round #411 (Div. 1)】Codeforces 804C Ice cream coloring (DFS)

    传送门 分析 这道题做了好长时间,题意就很难理解. 我们注意到这句话Vertices which have the i-th (1 ≤ i ≤ m) type of ice cream form a ...

  6. CF1060E Sergey and Subway(点分治)

    给出一颗$N$个节点的树,现在我们**在原图中**每个不直接连边但是中间只间隔一个点的两个点之间连一条边. 比如**在原图中**$u$与$v$连边,$v$与$w$连边,但是$u$与$w$不连边,这时候 ...

  7. USACO Training3.2 01串 By cellur925

    题目传送门 一句话题意:求长度为n的有m个1的大小为第k个的01串. 暑假我做的时候是真·大暴力,用二进制枚举,55分,成功T掉无数点. 正解:开始可以用计数类dp来“预处理”,状态和转移都比较好想. ...

  8. 详解基于linux环境MySQL搭建与卸载

    本篇文章将从实际操作的层面,讲解基于linux环境的mysql的搭建和卸载. 1  搭建mysql 1.1  官网下载mysql压缩包 下载压缩包时,可以先把安装包下载到本地,再上传到服务器,也可以在 ...

  9. iOS 将WKWebView内的HTML打印为PDF

    使用的webview为WKWebView,核心部分代码(Swift 4): // 创建打印渲染 let printPageRenderer:PDFRender = PDFRender() // 获取渲 ...

  10. 微服务dubbo面试题

    dubbo的工作原理? dubbo支持的序列化协议? dubbo的负载均衡和高可用策略?动态代理策略? dubbo的SPI思想? 如何基于dubbo进行服务治理.服务降级.失败重试以及超时重试? du ...