Tensorflow中的tf.argmax()函数
转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html
官方API定义
tf.argmax(input, axis=None, name=None, dimension=None)
Returns the index with the largest value across axes of a tensor.
Args:
- input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
- axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
- name: A name for the operation (optional).
Returns:
- A Tensor of type int64.
关于axis
定义中的axis与numpy中的axis是一致的,下面通过代码进行解释
import numpy as np
import tensorflow as tf
sess = tf.session()
m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
print type(m)
print m
-------------------------------------------------------------------------------
<type 'numpy.ndarray'>
[[ 0.09957541 -0.0965599 0.06064715 -0.03011306 0.05533558 0.17263047
-0.02660419 0.08313394 -0.07225946 0.04916157]
[ 0.11304571 0.02099175 0.03591062 0.01287777 -0.11302195 0.04822164
-0.06853487 0.0800944 -0.1155676 -0.01168544]
[ 0.15760773 0.05613248 0.04839646 -0.0218203 0.02233066 0.00929849
-0.0942843 -0.05943 0.08726917 -0.059653 ]
[ 0.02553608 0.07298559 -0.06958302 0.02948747 0.00232073 0.11875584
-0.08325859 -0.06616175 0.15124641 0.09522969]
[-0.04616683 0.01816062 -0.10866459 -0.12478453 0.01195056 0.0580056
-0.08500613 0.00635608 -0.00108647 0.12054099]]
m是一个5行10列的矩阵,类型为numpy.ndarray
#使用tensorflow中的tf.argmax()
col_max = sess.run(tf.argmax(m, 0) ) #当axis=0时返回每一列的最大值的位置索引
print col_max
row_max = sess.run(tf.argmax(m, 1) ) #当axis=1时返回每一行中的最大值的位置索引
print row_max
array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])
-------------------------------------------------------------------------------
#使用numpy中的numpy.argmax
row_max = m.argmax(0)
print row_max
col_max = m.argmax(1)
print col_max
array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])
可以看到tf.argmax()与numpy.argmax()方法的用法是一致的
- axis = 0的时候返回每一列最大值的位置索引
- axis = 1的时候返回每一行最大值的位置索引
- axis = 2、3、4...,即为多维张量时,同理推断
参考
Tensorflow中的tf.argmax()函数的更多相关文章
- TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同
tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...
- tf.argmax()函数作用
tf.argmax()函数原型: def argmax(input, axis=None, name=None, dimension=None, output_type=dtypes.int64) 作 ...
- tensorflow中使用tf.variable_scope和tf.get_variable的ValueError
ValueError: Variable conv1/weights1 already exists, disallowed. Did you mean to set reuse=True in Va ...
- tf.Session()函数的参数应用(tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/dcrmg/article/details ...
- 【Tensorflow】tf.argmax函数
tf.argmax(input, axis=None, name=None, dimension=None) 此函数是对矩阵按行或列计算最大值 参数 input:输入Tensor axis:0表示 ...
- [转载]tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定
tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...
- TensorFlow 中的 tf.train.exponential_decay() 指数衰减法
exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None) 使 ...
- tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope
tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享. tf通过 tf.get_variable() 可以建立或者获取一个共享的变量. tf.get ...
- tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定
tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...
随机推荐
- C#获取类库(DLL)的绝对路径
C#中当我们在写公共的类库的时候难免会调用一些xml配置文件,而这个配置文件的路径则非常重要,常用的方式就是写在web.config中,而我们也可以将配置文件直接放在dll的同级目录,那么怎么获得当前 ...
- 部署Spring web项目遇到的问题及解决方案
非常悲伤的一个提示: 错误源码: Caused by: java.lang.ArrayStoreException: sun.reflect.annotation.TypeNotPresentExce ...
- 『PyTorch』第五弹_深入理解Tensor对象_上:初始化以及尺寸调整
一.创建Tensor 特殊方法: t.arange(1,6,2)t.linspace(1,10,3)t.randn(2,3) # 标准分布,*size t.randperm(5) # 随机排序,从0到 ...
- .split(",", -1);和.split(",")的区别
.split(",", -1);和.split(",")的区别在于://eg:String a="河南省,,金水区".//a.split(& ...
- mysql导入导出数据过大命令
phpmyadmin 导入或者导出都是有限制的,当导入或者导出的数据会报错. 1.导入数据库 mysql -u root -p<备份文件的保存路径 eg: mysql -u root -p &l ...
- VS2010-自定义控件
1.自定义控件 (1)新建—项目,项目模板选择“类库”,取名smControl,填写项目文件保存目录,点击确定 (2)完成后在解决方案资源管理器中删除类Class1 (3)添加“用户控件”——在解决方 ...
- HDU 3697贪心
额...大意是你可以决定什么时候选课.然后呢.每五分钟只有一次机会选.每种课限制选课时间.问你能选到的课最多有多少. 感觉一点都不水.是自己太菜了吗? #include<stdio.h> ...
- Awk 从入门到放弃(1)–学习笔记
参考:朱双印博客 1. 将test文件中的内容打印出来:vmuser@vmuser-virtual-machine:~/panzidong/awk$ echo ddd > testvmuser@ ...
- 三个安装,手机看VIP电影。写给亲爱的学习
三个安装,看VIP电影. 市场安装firefox 安装Tempermonkey 打开firefox,点击右上角的三个点,点击附加组件 继续点击浏览全部firefox附加组件 在上面的搜索框输入 tam ...
- 关于CentOS 7 下的Oracle11g的proc编译器的一些常见问题
1.proc编译器配置问题 在使用proc将.pc文件编译成.c文件时出现一堆的错误,网上的答案七杂八杂的,都没有解决我的问题. 如下是我在使用过程中的一些错误: 由于我可能比较笨,实在是受不了网上那 ...