import tensorflow as tf

a = tf.random.shuffle(tf.range(5))
a
tf.sort(a, direction='DESCENDING')
# 返回索引
tf.argsort(a, direction='DESCENDING')
idx = tf.argsort(a, direction='DESCENDING')
tf.gather(a, idx)
idx = tf.argsort(a, direction='DESCENDING')
tf.gather(a, idx)
a = tf.random.uniform([3, 3], maxval=10, dtype=tf.int32)
a
tf.sort(a)
tf.sort(a, direction='DESCENDING')
idx = tf.argsort(a)
idx
# 返回前2个值
res = tf.math.top_k(a, 2)
res
res.values
res.indices
prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.7, 0.1]])
target = tf.constant([2, 0])
# 概率最大的索引在最前面
k_b = tf.math.top_k(prob, 3).indices
k_b
k_b = tf.transpose(k_b, [1, 0])
k_b
# 对真实值broadcast,与prod比较
target = tf.broadcast_to(target, [3, 2])
target
def accuracy(output, target, topk=(1, )):
maxk = max(topk)
batch_size = target.shape[0] pred = tf.math.top_k(output, maxk).indices
pred = tf.transpose(pred, perm=[1, 0])
target_ = tf.broadcast_to(target, pred.shape)
correct = tf.equal(pred, target_) res = []
for k in topk:
correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
correct_k = tf.reduce_sum(correct_k)
acc = float(correct_k / batch_size)
res.append(acc) return res
# 10个样本6类
output = tf.random.normal([10, 6])
# 使得所有样本的概率加起来为1
output = tf.math.softmax(output, axis=1)
# 10个样本对应的标记
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print(f'prob: {output.numpy()}')
pred = tf.argmax(output, axis=1)
print(f'pred: {pred.numpy()}')
print(f'label: {target.numpy()}') acc = accuracy(output, target, topk=(1, 2, 3, 4, 5, 6))
print(f'top-1-6 acc: {acc}')

吴裕雄--天生自然TensorFlow2教程:张量排序的更多相关文章

  1. 吴裕雄--天生自然TensorFlow2教程:测试(张量)- 实战

    import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets import os ...

  2. 吴裕雄--天生自然TensorFlow2教程:张量限幅

    import tensorflow as tf a = tf.range(10) a # a中小于2的元素值为2 tf.maximum(a, 2) # a中大于8的元素值为8 tf.minimum(a ...

  3. 吴裕雄--天生自然TensorFlow2教程:前向传播(张量)- 实战

    手写数字识别流程 MNIST手写数字集7000*10张图片 60k张图片训练,10k张图片测试 每张图片是28*28,如果是彩色图片是28*28*3-255表示图片的灰度值,0表示纯白,255表示纯黑 ...

  4. 吴裕雄--天生自然TensorFlow2教程:手写数字问题实战

    import tensorflow as tf from tensorflow import keras from keras import Sequential,datasets, layers, ...

  5. 吴裕雄--天生自然TensorFlow2教程:函数优化实战

    import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D def himme ...

  6. 吴裕雄--天生自然TensorFlow2教程:反向传播算法

  7. 吴裕雄--天生自然TensorFlow2教程:链式法则

    import tensorflow as tf x = tf.constant(1.) w1 = tf.constant(2.) b1 = tf.constant(1.) w2 = tf.consta ...

  8. 吴裕雄--天生自然TensorFlow2教程:多输出感知机及其梯度

    import tensorflow as tf x = tf.random.normal([2, 4]) w = tf.random.normal([4, 3]) b = tf.zeros([3]) ...

  9. 吴裕雄--天生自然TensorFlow2教程:单输出感知机及其梯度

    import tensorflow as tf x = tf.random.normal([1, 3]) w = tf.ones([3, 1]) b = tf.ones([1]) y = tf.con ...

随机推荐

  1. HDU 3065 病毒侵袭持续中 (模板题)

    病毒侵袭持续中 Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)Total Sub ...

  2. SpringData JPA使用JPQL的方式查询和使用SQL语句查询

    使用Spring Data JPA提供的查询方法已经可以解决大部分的应用场景,但是对于某些业务来说,我们还需要灵活的构造查询条件, 这时就可以使用@Query注解,结合JPQL的语句方式完成查询 持久 ...

  3. 最新获取 QQ头像 和 昵称接口

    网上找来的测试可用... 获取QQ头像 http://q2.qlogo.cn/headimg_dl?bs=QQ号&dst_uin=QQ号&dst_uin=QQ号&;dst_ui ...

  4. Windows下使用nginx问题

    1.下载完成后,解压缩,运行cmd,使用命令进行操作,不要直接双击nginx.exe,不要直接双击nginx.exe,不要直接双击nginx.exe 一定要在dos窗口启动,不要直接双击nginx.e ...

  5. centos 禁用ip v6

    #  sysctl -w net.ipv6.conf.all.disable_ipv6=1 #  sysctl -w net.ipv6.conf.default.disable_ipv6=1 #  s ...

  6. Koa微信公众号开发

    微信开发者模式开启需要服务器域名合法并且把接口配置好,这个接口是接通的关键,接通后微信后台的菜单设置功能,客服功能会失效,需要开发者自定义菜单和智能客服界面,并且接通后可以调用微信网页内部的定位分享等 ...

  7. Spark Scheduler 模块(下)

    Scheduler 模块中最重要的两个类是 DAGScheduler 和 TaskScheduler.上篇讲了 DAGScheduler,这篇讲 TaskScheduler. TaskSchedule ...

  8. java.lang.IllegalStateException: Active Spring transaction synchronization or active JTA transaction with specified [javax.transaction.TransactionManager] required

    错误信息: java.lang.IllegalStateException: Active Spring transaction synchronization or active JTA trans ...

  9. redis学习(五)

    一.Redis 发布订阅 1.Redis 发布订阅(pub/sub)是一种消息通信模式:发送者(pub)发送消息,订阅者(sub)接收消息. 2.Redis 客户端可以订阅任意数量的频道. 比如你订阅 ...

  10. Golang的标识符命名规则

    Golang的标识符命名规则 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.关键字 1>.Go语言有25个关键字 Go语言的25个关键字如下所示: break,defau ...