相关文章:

【一】tensorflow安装、常用python镜像源、tensorflow 深度学习强化学习教学

【二】tensorflow调试报错、tensorflow 深度学习强化学习教学

【三】tensorboard安装、使用教学以及遇到的问题

【四】超级快速pytorch安装


trick1---实现tensorflow和pytorch迁移环境教学

  • 张量shape参数理解

shape参数的个数应为维度数,每一个参数的值代表该维度上的长度

  1. shape=(100,784)
  2. 代表该张量有两个维度,第一个维度长度为100,第二个维度长度为784,二维数组100784
  3. shape=(2,)
  4. 代表该张量有一个维度,第一个维度长度为2,一维数组12

第几个维度的长度,就是左数第几个中括号组之间的元素总数量

  1. # 例:
  2. [[[1,2,3],[4,5,6]]]
  3. # 第一个维度中只有一个元素[[1,2,3][4,5,6]],所以第一个维度长度为1
  4. # 第二个维度中有两个元素[1,2,3][4,5,6],所以第二个维度长度为2
  5. # 第三个维度中有三个元素“1,2,3”或“4,5,6”,所以第三个维度长度为3
  6. # 那么它的shape参数就是[1,2,3]
  • tf.trainable_variables(), tf.global_variables()的使用

tf.trainable_variables():

这个函数可以查看可训练的变量,参数trainable,其默认为True

  1. __init__(
  2. initial_value=None,
  3. trainable=True,
  4. collections=None,
  5. validate_shape=True,
  6. ...
  7. )

对于一些我们不需要训练的变量,将trainable设置为False,这时tf.trainable_variables() 就不会打印这些变量。

举个简单的例子,在下图中共定义了4个变量,分别是一个权重矩阵,一个偏置向量,一个学习率和计步器,其中前两项是需要训练的而后两项则不需要。

  1. w1 = tf. Variable (tf. randon_normal ([256, 2000]),'w1' )
  2. b1 = tf.get_ variable('b1', [2000])
  3. learning_ rate = tf. Variable(0.5, trainable=False)
  4. global_ step = tf. Variable(0, trainable=False)
  1. trainable_ params = tf. trainable_ variables()
  2. trainable_ params
  3. [<tf. VariableVariable:0' shape= (256,2000) dtype=float32_ ref>,
  4. <tf. Variable’ b1:0”shape= (2000,) dtype=float32_ ref>]

另一个问题就是,如果变量定义在scope域中,是否会有不同。实际上,tf.trainable_variables()是可以通过参数选定域名的,如下图所示:

  1. vith tf. variable_ scope(' var' ):
  2. w2 = tf.get. variable('w2' , [3, 3])
  3. w3 = tf.get. variable(' w3',[3, 3])

我们重新声明了两个新变量,其中w2是在‘var’中的,如果我们直接使用tf.trainable_variables(),结果如下

  1. trainable. params = tf.trainable.variables ()
  2. trainable_ params
  3. [<tf. Variable vrar/w2:0shape=(3, 3) dtype=float32_ ref>,
  4. <tf. Variablew3:0' shape=(3, 3) dtype=float32_ ref>]

如果我们只希望查看‘var’域中的变量,我们可以通过加入scope参数的方式实现:

  1. scope_ parans = tf. trainable_ variables (scope-' var' )
  2. scope par ains
  3. [<tf. Variable var/w2:0' shape=(3, 3) dtype=float32_ ref>]

tf.global_variables()

如果我希望查看全部变量,包括我的学习率等信息,可以通过tf.global_variables()来实现。效果如下:

  1. global parans = tf. global variables()
  2. global_ params
  3. [<tf. VariableVariable:0' shape=(256, 2000) dtype=float32_ ref>,
  4. <tf. Variable ' b1:0' shape= (2000,) dtype-float32_ ref>,
  5. <tf. Variable。Variable_ 1:0’shape=0 dtype=float32_ ref>,
  6. <tf. Variable' Variable_ 2:0 shape=() dtype=int32_ ref>]

这时候打印出来了4个变量,其中后两个即为trainable=False的学习率和计步器。与tf.trainable_variables()一样,tf.global_variables()也可以通过scope的参数来选定域中的变量。

  • Optimizer.minimize()与Optimizer.compute_gradients()和Optimizer.apply_gradients()的用法

Optimizer.minimize()

minimize()就是compute_gradients()和apply_gradients()这两个方法的简单组合,minimize()的源码如下:

  1. def minimize(self, loss, global_step=None, var_list=None,
  2. gate_gradients=GATE_OP, aggregation_method=None,
  3. colocate_gradients_with_ops=False, name=None,
  4. grad_loss=None):
  5. grads_and_vars = self.compute_gradients(
  6. loss, var_list=var_list, gate_gradients=gate_gradients,
  7. aggregation_method=aggregation_method,
  8. colocate_gradients_with_ops=colocate_gradients_with_ops,
  9. grad_loss=grad_loss)
  10. vars_with_grad = [v for g, v in grads_and_vars if g is not None]
  11. if not vars_with_grad:
  12. raise ValueError(
  13. "No gradients provided for any variable, check your graph for ops"
  14. " that do not support gradients, between variables %s and loss %s." %
  15. ([str(v) for _, v in grads_and_vars], loss))
  16. return self.apply_gradients(grads_and_vars, global_step=global_step,
  17. name=name)

主要的参数说明:

loss:  `Tensor` ,需要优化的损失; 
      var_list: 需要更新的变量(tf.Varialble)组成的列表或者元组,默认值为`GraphKeys.TRAINABLE_VARIABLES`,即tf.trainable_variables()

注意:

1、Optimizer.minimize(loss, var_list)中,计算loss所涉及的变量(假设为var(loss))包含在var_list中,也就是var_list中含有多余的变量,并不 影响程序的运行,而且优化过程中不改变var_list里多出变量的值;

2、若var_list中的变量个数少于var(loss),则优化过程中只会更新var_list中的那些变量的值,var(loss)里多出的变量值 并不会改变,相当于固定了网络的某一部分的参数值。

compute_gradients()和apply_gradients()

  1. compute_gradients(self, loss, var_list=None,
  2. gate_gradients=GATE_OP,
  3. aggregation_method=None,
  4. colocate_gradients_with_ops=False,
  5. grad_loss=None):

里面参数的定义与minimizer()函数里面的一致,var_list的默认值也一样。需要特殊说明的是,如果var_list里所包含的变量多于var(loss),则程序会报错。其返回值是(gradient, variable)对所组成的列表,返回的数据格式也都是“tf.Tensor”。我们可以通过变量名称的管理来过滤出里面的部分变量,以及对应的梯度。
apply_gradients()的源码如下:

  1. apply_gradients(self, grads_and_vars, global_step=None, name=None)

grads_and_vars的格式就是compute_gradients()所返回的(gradient, variable)对,当然数据类型也是“tf.Tensor”,作用是,更新grads_and_vars中variable的梯度,不在里面的变量的梯度不变。

tensorflow语法【shape、tf.trainable_variables()、Optimizer.minimize()】的更多相关文章

  1. tf.trainable_variables()

    https://blog.csdn.net/shwan_ma/article/details/78879620 一般来说,打印tensorflow变量的函数有两个:tf.trainable_varia ...

  2. tensorflow 生成随机数 tf.random_normal 和 tf.random_uniform 和 tf.truncated_normal 和 tf.random_shuffle

    ____tz_zs tf.random_normal 从正态分布中输出随机值. . <span style="font-size:16px;">random_norma ...

  3. tensorflow 基本函数(1.tf.split, 2.tf.concat,3.tf.squeeze, 4.tf.less_equal, 5.tf.where, 6.tf.gather, 7.tf.cast, 8.tf.expand_dims, 9.tf.argmax, 10.tf.reshape, 11.tf.stack, 12tf.less, 13.tf.boolean_mask

    1.  tf.split(3, group, input)  # 拆分函数    3 表示的是在第三个维度上, group表示拆分的次数, input 表示输入的值 import tensorflow ...

  4. Tensorflow 学习笔记 -----tf.where

    TensorFlow函数:tf.where 在之前版本对应函数tf.select 官方解释: tf.where(input, name=None)` Returns locations of true ...

  5. 【TensorFlow基础】tf.add 和 tf.nn.bias_add 的区别

    1. tf.add(x,  y, name) Args: x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, ...

  6. Tensorflow中的tf.argmax()函数

    转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html 官方API定义 tf.argmax(input, axis=None, name=None ...

  7. tensorflow中使用tf.variable_scope和tf.get_variable的ValueError

    ValueError: Variable conv1/weights1 already exists, disallowed. Did you mean to set reuse=True in Va ...

  8. tf.trainable_variables和tf.all_variables的对比

    tf.trainable_variables返回的是可以用来训练的变量列表 tf.all_variables返回的是所有变量的列表

  9. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  10. tensorflow笔记:使用tf来实现word2vec

    (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 (四) tensorflow笔 ...

随机推荐

  1. 聊聊大语言模型(LLM)的 10 个实际应用

    近期,苹果公司正在悄悄研究可以挑战的 OpenAI.谷歌和其他公司的 AI 工具,建立了自己的框架来创建大语言模型,并创建了一个聊天机器人服务,一些工程师称之为"Apple GPT" ...

  2. Mysql--数据的导入导出以及备份

    一.导入导出 1.1.into outfile(只导出数据) 注意:mysql 5.7+版本,secure_file_priv 的值默认为NULL,即不允许导入或导出,需在 /etc/my.cnf 添 ...

  3. #1495:非常可乐(BFS+数论)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1495 BFS解法 题目 给三个数字 s n m s=n+m s在1到100之间 就是个倒水问题可以从第 ...

  4. Codeforces Round #681 (Div. 2, based on VK Cup 2019-2020 - Final) 个人题解(A - D)

    1443A. Kids Seating 题意: 给你一个整数n,现在你需要从编号 \(1\) ~ $4 ⋅ n \(中选出\)n\(个编号使得这些编号之间\)g c d ≠ 1$ ,不能整除. 看了半 ...

  5. 快捷键:mysql + idea + 浏览器

    mysql快捷键:ctrl+r 运行查询窗口的sql语句ctrl+shift+r 只运行选中的sql语句ctrl+q 打开一个新的查询窗口ctrl+w 关闭一个查询窗口ctrl+/ 注释sql语句 c ...

  6. 【体验有奖】使用 Serverless 1 步搭建照片平台!

    实验介绍 当前,Serverless 技术已经被广泛应用,Serverless = FaaS + BssS 的概念已经深入人心.本场景由函数计算和 RDS MySQL Serverless 联合打造, ...

  7. 打破 Serverless 落地边界,阿里云 SAE 发布 5 大新特性

    微服务场景,开源自建真的最快最省最稳的?复杂性真的会成为 Kubernetes 的"致命伤"吗?企业应用容器化,一定得过 K8s 这座"独木桥"吗?Server ...

  8. mybatisplus 查询结果排除某字段实现

    数据有Test表,表里有id,name,ip_address,last_time四个字段 通常查询写法,返回结果会把id,name,ip_address,last_time四个字段都返回 public ...

  9. 二、Mycat安装

    系列导航 一.Mycat实战---为什么要用mycat 二.Mycat安装 三.mycat实验数据 四.mycat垂直分库 五.mycat水平分库 六.mycat全局自增 七.mycat-ER分片 万 ...

  10. tinymce富文本编辑器升级问题

    突然这样,之前好好地.