tensorflow最大的问题就是大家都讲算法,不讲解用法,API文档又全是英文的,看起来好吃力,理解又不到位。当然给数学博士看的话,就没问题的。

最近看了一系列非常不错的文章,做一下记录:

https://www.zhihu.com/people/hong-lan-99/activities

https://github.com/lanhongvp

https://blog.csdn.net/qq_37747262

https://blog.csdn.net/qq_37747262/article/details/82223155

特别是关于他的填坑记系列的。我发现我看不懂 tf.metrics.precision_at_k这个代码,在知乎上也找到了他解释的文档。奈何解释的我看不太懂。

以下是他在知乎上插入的代码

作者:洪澜
链接:https://www.zhihu.com/question/277184041/answer/480219663
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 import tensorflow as tf
import numpy as np y_true = np.array([[2], [1], [0], [3], [0], [1]]).astype(np.int64)
y_true = tf.identity(y_true) y_pred = np.array([[0.1, 0.2, 0.6, 0.1],
[0.8, 0.05, 0.1, 0.05],
[0.3, 0.4, 0.1, 0.2],
[0.6, 0.25, 0.1, 0.05],
[0.1, 0.2, 0.6, 0.1],
[0.9, 0.0, 0.03, 0.07]]).astype(np.float32)
y_pred = tf.identity(y_pred) _, m_ap = tf.metrics.sparse_average_precision_at_k(y_true, y_pred, 2) sess = tf.Session()
sess.run(tf.local_variables_initializer()) stream_vars = [i for i in tf.local_variables()]
print((sess.run(stream_vars))) tf_map = sess.run(m_ap)
print(tf_map) tmp_rank = tf.nn.top_k(y_pred,4)
print(sess.run(tmp_rank))

以下是他的解释

作者:洪澜
链接:https://www.zhihu.com/question/277184041/answer/480219663
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

  • 简单解释一下,首先y_true代表标签值(未经过one-hot)shape:(batch_size, num_labels) ,y_pred代表预测值(logit值) ,shape:(batch_size, num_classes)
  • 其次,要注意的是tf.metrics.sparse_average_precision_at_k中会采用top_k根据不同的k值对y_pred进行排序操作 ,所以tmp_rank是为了帮助大噶理解究竟y_pred在函数中进行了怎样的转换。
  • 然后,stream_vars = [i for i in tf.local_variables()]这一行是为了帮助大噶理解 tf.metrics.sparse_average_precision_at_k创建的tf.local_varibles 实际输出值,进而可以更好地理解这个函数的用法。
  • 具体看这个例子,当k=1时,只有第一个batch的预测输出是和标签匹配的 ,所以最终输出为:1/6 = 0.166666 ;当k=2时,除了第一个batch的预测输出,第三个batch的预测输出也是和标签匹配的,所以最终输出为:(1+(1/2))/6 = 0.25
======
但是还是不太理解他的解释:为什么K=1的时候回算出来匹配的是1,而k=2时算出来是0.25
经过自己调试和摸索,终于明白了作者的用意:
 
#调用以下代码
tmp_rank = tf.nn.top_k(y_pred,4)
print(sess.run(tmp_rank))
'''
就会得到类似的东西
TopKV2(values=array([[0.6 , 0.2 , 0.1 , 0.1 ],
[0.8 , 0.1 , 0.05, 0.05],
[0.4 , 0.3 , 0.2 , 0.1 ],
[0.6 , 0.25, 0.1 , 0.05],
[0.6 , 0.2 , 0.1 , 0.1 ],
[0.9 , 0.07, 0.03, 0. ]], dtype=float32), indices=array([[2, 1, 0, 3],
[0, 2, 1, 3],
[1, 0, 3, 2],
[0, 1, 2, 3],
[2, 1, 0, 3],
[0, 3, 2, 1]]))
'''

通过对tf.nn.top_k的调用以及返回的结果,可以明白函数大概有几个作用

1. 把y_pred中的数值进行了从大到小的重新排列

2. 计算得到现在位置上的数据原来所在的位置

K值的作用是指定只计算多少个。作者的解释

  • 具体看这个例子,当k=1时,只有第一个batch的预测输出是和标签匹配的 ,所以最终输出为:1/6 = 0.166666 ;当k=2时,除了第一个batch的预测输出,第三个batch的预测输出也是和标签匹配的,所以最终输出为:(1+(1/2))/6 = 0.25

就比较容易解释了。简单的理解,就是后面输出的原来的位置和y_true进行比较,如果匹配就增加一,以前是一列数据去比较,现在成了K列,然后得到的数字/K。那么这个函数这样做的目的和意义在哪里呢?

我个人的理解,如果K=1的时候,是预测概率最大的和标签匹配的概率是多少,K=2的时候,计算的是概率最大的列和其次的最大概率恰好与标签匹配的概率的平均准确度有多少。

 

tf.metrics.sparse_average_precision_at_k 和 tf.metrics.precision_at_k的自己理解的更多相关文章

  1. tf.nn.conv2d 和 tf.nn.max_pool 中 padding 分别为 'VALID' 和 'SAME' 的直觉上的经验和测试代码

    这个地方一开始是迷糊的,写代码做比较分析,总结出直觉上的经验. 某人若想看精准的解释,移步这个网址(http://blog.csdn.net/fireflychh/article/details/73 ...

  2. 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在

    1. tf.nn.moments(x, axes=[0, 1, 2])  # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...

  3. TF之RNN:TF的RNN中的常用的两种定义scope的方式get_variable和Variable—Jason niu

    # tensorflow中的两种定义scope(命名变量)的方式tf.get_variable和tf.Variable.Tensorflow当中有两种途径生成变量 variable import te ...

  4. 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位

    1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...

  5. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  6. TensorFlow 辨异 —— tf.add(a, b) 与 a+b(tf.assign 与 =)、tf.nn.bias_add 与 tf.add(转)

    1. tf.add(a, b) 与 a+b 在神经网络前向传播的过程中,经常可见如下两种形式的代码: tf.add(tf.matmul(x, w), b) tf.matmul(x, w) + b 简而 ...

  7. tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope

    tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享. tf通过 tf.get_variable() 可以建立或者获取一个共享的变量. tf.get ...

  8. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  9. tensorflow 基本函数(1.tf.split, 2.tf.concat,3.tf.squeeze, 4.tf.less_equal, 5.tf.where, 6.tf.gather, 7.tf.cast, 8.tf.expand_dims, 9.tf.argmax, 10.tf.reshape, 11.tf.stack, 12tf.less, 13.tf.boolean_mask

    1.  tf.split(3, group, input)  # 拆分函数    3 表示的是在第三个维度上, group表示拆分的次数, input 表示输入的值 import tensorflow ...

随机推荐

  1. C语言之const

    鱼鹰  鱼鹰谈单片机 2月19日 预计阅读时间: 5 分钟 我们知道,数据分为两种,一种为只读,一种为可读可写,为了防止一些不变的数据被程序意外的修改,有必要对它进行保护.这就是 const 的作用. ...

  2. 【git】git中使用https和ssh协议的区别以及它们的用法

    git可以使用四种主要的协议来传输资料: 本地协议(Local),HTTP 协议,SSH(Secure Shell)协议及 git 协议.其中,本地协议由于目前大都是进行远程开发和共享代码所以一般不常 ...

  3. removeClass([class|fn])

    removeClass([class|fn]) 概述 从所有匹配的元素中删除全部或者指定的类.直线电机生产厂家   参数 classStringV1.0 一个或多个要删除的CSS类名,请用空格分开 f ...

  4. PHP mysqli_query() 函数

    PHP mysqli_query() 函数 定义和用法 mysqli_query() 函数执行某个针对数据库的查询. mysqli_query(connection,query,resultmode) ...

  5. vue项目实现详情页后退缓存之前的数据

    vue项目实现详情页后退缓存之前的数据 2019年02月19日 14:54:57 不想写代码的程序员 阅读数:244   一.需要缓存的内容: 1.后退缓存条件查询的数据 2.后退缓存分页信息 二.实 ...

  6. logstash6.5.4同步mysql数据到elasticsearch 6.4.1

    下载logstash-6.5.4 ZIP解压和es 放到es根目录下 下载mysql jdbc的驱动 mysql-connector-java-8.0.12 放在任意目录下 以下方式采用动态模板,还有 ...

  7. MIME协议(一) -- RFC822邮件格式

    MIME协议(一) -- RFC822邮件格式 .   如同其他各种电子文档一样,电子邮件内容也必须遵循一定的格式要求,各种邮件处理程序才能从中分析和提取出发件人.收件人.主题和附件等信息.邮件内容的 ...

  8. find命令计算代码行数

    [anonymous@localhost ~/lvs/ipvsadm- -regex '.*Makefile.*' -o -regex '.*\.[ch]' -exec cat {} \; | wc ...

  9. HTTP之Cookie和Session

    1. Cookie 1.1 为什么需要 Cookie? HTTP 协议是一种无状态的协议,也就是说,当前的 HTTP 请求与以前的 HTTP 请求没有任何联系.显然,这种无状态的情形在某些时候将让用户 ...

  10. LeetCode 215. 数组中的第K个最大元素(Kth Largest Element in an Array)

    题目描述 在未排序的数组中找到第 k 个最大的元素.请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素. 示例 1: 输入: [3,2,1,5,6,4] 和 k = 2 ...