『TensorFlow』one_hot化标签
tf.one_hot(indices, depth):将目标序列转换成one_hot编码
tf.one_hot
(indices, depth, on_value=None, off_value=None,
axis=None, dtype=None, name=None)indices = [0, 2, -1, 1]
depth = 3
on_value = 5.0
off_value = 0.0
axis = -1
#Then output is [4 x 3]:
output =
[5.0 0.0 0.0] // one_hot(0)
[0.0 0.0 5.0] // one_hot(2)
[0.0 0.0 0.0] // one_hot(-1)
[0.0 5.0 0.0] // one_hot(1)
with tf.Session() as sess:
print(sess.run(tf.one_hot(np.array([np.array([0,1,2,3]),np.array([2,0,3,2])]),depth=4,axis=-1))) # [[[ 1. 0. 0. 0.]
# [ 0. 1. 0. 0.]
# [ 0. 0. 1. 0.]
# [ 0. 0. 0. 1.]]
# [[ 0. 0. 1. 0.]
# [ 1. 0. 0. 0.]
# [ 0. 0. 0. 1.]
# [ 0. 0. 1. 0.]]] oh = tf.one_hot(indices = [0, 2, -1, 1], depth = 3, on_value = 5.0 , off_value = 0.0, axis = -1)
sess = tf.Session()
sess.run(oh) # array([[5., 0., 0.],
# [0., 0., 5.],
# [0., 0., 0.],
# [0., 5., 0.]], dtype=float32)
另一种思路:稀疏张量构建法
import numpy as np
import tensorflow as tf NUMCLASS = 3
batch_size = 5 labels = tf.placeholder(dtype=tf.int32, shape=[batch_size, 1])
index = tf.reshape(tf.range(0, batch_size,1), [batch_size, 1])
one_hot = tf.sparse_to_dense(
tf.concat(values=[index, labels], axis=1),
[batch_size, NUMCLASS],
1.0, 0.0
)
with tf.Session() as sess:
lab = np.random.randint(0,3,[5,1])
print(sess.run(one_hot, feed_dict={labels:lab}))
print(sess.run(tf.one_hot(np.squeeze(lab),depth=3,axis=1)))
注意两种方法输入数据维度的变化(稀疏法为了得到足够的索引需要升维),结果如下:
[[ 1. 0. 0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]
[ 0. 1. 0.]]
[[ 1. 0. 0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]
[ 0. 1. 0.]]
『TensorFlow』one_hot化标签的更多相关文章
- 『TensorFlow』专题汇总
TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...
- 『TensorFlow』TFR数据预处理探究以及框架搭建
一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...
- 『TensorFlow』读书笔记_降噪自编码器
『TensorFlow』降噪自编码器设计 之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』分布式训练_其三_多机分布式
本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...
- 『TensorFlow』DCGAN生成动漫人物头像_下
『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...
- 『TensorFlow』滑动平均
滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...
- 『TensorFlow』流程控制
『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...
- 『TensorFlow』梯度优化相关
tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...
随机推荐
- Cookie:解决HTTP协议无保存状态
客户端 Cookie会根据从服务器端发送的相应报文内一个叫Set-Cookie的首部字段信息,通知客户端保存Cookie.当下次客户端再往该服务器发送请求时,客户端会自动在请求报文中加入Cookie值 ...
- JDBC事务(三)ThreadLocal绑定Connection
处理一个请求即开启一个线程,在三层中,执行三层中的方法都是用的同一个线程. 我们开启一个事务,使用conn.setAutoCommit(false); conn应该属于ado层,不应该出现在servi ...
- Linux Input子系统
先贴代码: //input.c int input_register_handler(struct input_handler *handler) { //此处省略很多代码 list_for_each ...
- 访问GitLab的PostgreSQL数据库,查询、修改、替换等操作
1.登陆gitlab的安装服务查看配置文件 cat /var/opt/gitlab/gitlab-rails/etc/database.yml production: adapter: postgre ...
- (转载)中文Appium API 文档
该文档是Testerhome官方翻译的源地址:https://github.com/appium/appium/tree/master/docs/cn官方网站上的:http://appium.io/s ...
- 与数论的厮守02:整数的因子分解—Pollard_Rho
学Pollard_Rho之前,你需要学会:Miller Rabin. 这是一个很高效的玄学算法,用来对大整数进行因数分解. 我们来分解n.若n是一个素数,那么就不需要分解了.所以我们还得能够判断一个数 ...
- caffe运行报错:datum channel>0(0:0)
caffe在运行的时候报错:datum channel>0(0:0) 错误原因:数据通道错误,caffe不能识别 解决方案:不告诉你
- asp.net 导出excel--NPOI
1.使用OLEDB导出Excel ,这种方式有点慢,慎用 /// <summary> /// 使用OLEDB导出Excel /// </summary> /// <par ...
- Scala 偏函数
如果你想定义一个函数,而让它只接受和处理其参数定义域范围内的子集,对于这个参数范围外的参数则抛出异常,这样的函数就是偏函数(顾名思异就是这个函数只处理传入来的部分参数). 偏函数是个特质其的类型为Pa ...
- 使用查询分析器和SQLCMD分别登录远程的SQL2005的1434端口
SQLCMD是操作SQLSERVER的一个命令行工具, 而查询分析器是它的图形工具 查询分析器(SQL2005下叫managerment studio),连接远程的SQLSERVER2005, ...