Tensorflow之TFRecord的原理和使用心得
本文首发于微信公众号「对白的算法屋」
大家好,我是对白。
目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上。Hive作为构建在HDFS上的一个数据仓库,它本质上可以看作是一个翻译器,可以将HiveSQL语句翻译成MapReduce程序或Spark程序,因此模型需要的数据例如csv/libsvm文件都会保存成Hive表并存放在HDFS上,那么问题就来了,如何大规模地把HDFS中的数据直接喂到Tensorflow中呢?Tensorflow提供了一种解决方法:spark-tensorflow-connector,支持将spark DataFrame格式数据直接保存为TFRecords格式数据,接下来就带大家了解一下TFRecord的原理、构成和如何生成TFRecords文件。
TFRecord介绍
TFRecord是Tensorflow训练和推断标准的数据存储格式之一,将数据存储为二进制文件(二进制存储具有占用空间少,拷贝和读取(from disk)更加高效的特点),而且不需要单独的标签文件了,其本质是一行行字节字符串构成的样本数据。
一条TFRecord数据代表一个Example,一个Example就是一个样本数据,每个Example内部由一个字典构成,每个key对应一个Feature,key为字段名,Feature为字段名所对应的数据,Feature有三种数据类型:ByteList、FloatList,Int64List。
TFRecord构成
它实质上是由protobuf定义的一种数据协议,其中tensorflow提供了两种Example表示形式 Example和SequenceExample。它的定义代码位于[tensroflow/core/example/example.proto & feature.proto]。
Example和SequenceExample的定义:
message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
// Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};
我们这里以最常用的Example来进行解释。从图中可以看到,在样本生产环节,每个Example内部由一个dict构成,每个key(string)对应着一个Feature结构,这个Feature结构有三种具体形式,分别是ByteList,FloatList,Int64List三种。这三种形式便可以承载string,bytes,float,double,int,long等多种样本结构,并且基于list的表示,使得我们既可以表达scalar,也可以表达vector类型的数据(注意如果想要将一个matrix保存到到一个Feature内,其值需要时按照Row-Major拍平的1-D array, 行列数据需使用额外字段保存,方便反序列化)。这里需要注意的是,我们在序列化的时候,并未将格式信息序列化进去,实质上,序列化后的,每条tfrecord中的数据,只具有以下数据:
TFRecord中每条数据的格式:
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
因此我们可以看出来,TFRecord并不是一个self-describing的格式,也就是说,tfrecord的write和read都需要额外指明schema。从上图我们也能看出来,在实际训练的时候,样本都需要经过一个知晓了Schema的Parser来进行解析,然后才能传递给Tensorflow进行实际的训练。
注:这里只展示了CTR场景常使用的Example,当然也有图像等场景需要使用SequenceExample进行一些样本的结构化表达,这里不做展开。根据官方文档来看,SequenceExample主要是使用在时序特征和视频特征。其中context字段描述的是和当期时间和特征不相关的共性数据,而feature_list则持有和时间或者视频帧相关的数据。感兴趣可以参考youtube-8M这个数据集中关于样本数据的表示。
TFRecord的生成(小规模)
TFRecord的生成=Example序列化+写入TFRecord文件
构建Example时需要指定格式信息(字典)key是特征,value是BytesList/FloatList/Int64List值,但Example序列化时并未将格式信息序列化进去,因此读取TFRecord文件需要额外指明schema。
每个Example会序列化成字节字符串并写入TFRecord文件中,代码如下:
import tensorflow as tf
# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _int64list_feature(value_list):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))
# Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
# 注意我们需要按照格式来进行数据的组装,这里的dict便按照指定Schema构造了一条Example
feature = {
'user_id': _int64_feature(user_id),
'city_id': _int64_feature(city_id),
'app_type': _int64_feature(app_type),
'viewd_pois': _int64list_feature(viewd_pois),
'avg_paid': _float_feature(avg_paid),
'comment': _bytes_feature(comment),
}
# 调用相关api将Example序列化为字节字符串
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
# 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
with tf.python_io.TFRecordWriter(filepath) as writer:
writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
print "write demo data done."
filepath = "testdata.tfrecord"
write_demo(filepath)
由以上代码可知,TFRecord的原理是:将每个样本传给serialize_example函数并输出字节字符串,再通过TFRecordWriter类写入TFRecord文件中,有多少个样本就会生成多少个字节字符串。
TFRecord的生成(大规模)
TFRecord的生成=spark DataFrame格式数据保存为tfrecords格式数据
from pyspark.sql.types import *
def main():
#从hive表中读取数据
df=spark.sql("""
select * from experiment.table""")
#tfrecords保存路径
path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
#将spark DataFrame格式数据转换为tfrecords格式数据
df.repartition(file_num).write \
.mode("overwrite") \
.format("tfrecords") \
.option("recordType", "Example")\
.save(path)
if __name__ == "__main__":
main()
TFRecord的读取
在模型训练的时候需要读取TFRecord文件,有三个步骤:
1、首先通过tf.data.TFRecordDataset() API读取TFRecord文件并创建dataset;
2、定义schema;
3、使用tf.parse_single_example() 按照schema解析dataset中每个样本;
schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。
Tensorflow提供了三种解析函数:
1、tf.FixedLenFeature(shape,dtype,default_value):解析定长特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;
2、tf.VarLenFeature(dtype):解析变长特征,dtype:输入数据类型;
3、tf.FixedSequenceFeature(shape,dtype,default_value):解析定长序列特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;
代码如下:
def read_demo(filepath):
# 定义schema
schema = {
'user_id': tf.FixedLenFeature([], tf.int64),
'city_id': tf.FixedLenFeature([], tf.int64),
'app_type': tf.FixedLenFeature([], tf.int64),
'viewed_pois': tf.VarLenFeature(tf.int64),
'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
'comment': tf.FixedLenFeature([], tf.string, default_value=''),
}
# 使用相关api,按照schema解析dataset中的样本
def _parse_function(example_proto):
return tf.parse_single_example(example_proto, schema)
# 读取TFRecord文件来创建dataset
dataset = tf.data.TFRecordDataset(filepath)
#按照schema解析dataset中的每个样本
parsed_dataset = dataset.map(_parse_function)
#创建Iterator并迭代Iterator即可访问dataset中的样本
next = parsed_dataset.make_one_shot_iterator().get_next()
# 这里直接利用session,打印dataset中的样本
with tf.Session() as sess:
while True:
try:
print sess.run(next)
except:
print "out of data"
break
其中,
tf.parse_single_example(
serialized,
features,
name=None,
example_names=None
)
参数:
- serialized:序列化的Example。
- features:一个字典,key是特征,value是FixedLenFeature/VarLenFeature/FixedSequenceFeature值。
- name:此操作的名称(可选)。
- example_names:(可选)标量字符串张量,关联的名称。
返回:
一个字典,key是特征,value是Tensor或Sparse Tensor值。
最后欢迎大家关注我的微信公众号:对白的算法屋(duibainotes),跟踪NLP、推荐系统和对比学习等机器学习领域前沿。
想进一步交流的同学也可以通过公众号加我的微信一同探讨技术问题,谢谢。
推荐阅读
Tensorflow之TFRecord的原理和使用心得的更多相关文章
- 4. Tensorflow的Estimator实践原理
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- 3. Tensorflow生成TFRecord
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- tensorflow核心概念和原理介绍
关于 TensorFlow TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库. 节点(Nodes)在图中表示数学操作,图中的线(edges)则表示 ...
- Tensorflow 读写 tfrecord 文件(Python3)
TensorFlow笔记博客:https://blog.csdn.net/xierhacker/article/category/6511974 写入tfrecord文件 import tensorf ...
- tensorflow制作tfrecord格式数据
tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...
- 《转》从系统和代码实现角度解析TensorFlow的内部实现原理 | 深度
from https://www.leiphone.com/news/201702/n0uj58iHaNpW9RJG.html?viewType=weixin 摘要 2015年11月9日,Google ...
- tensorflow的tfrecord操作代码与数据协议规范
tensorflow的数据集可以说是非常重要的部分,我认为人工智能就是数据加算法,数据没处理好哪来的算法? 对此tensorflow有一个专门管理数据集的方式tfrecord·在训练数据时提取图片与标 ...
- 移动端开发rem布局之less+媒体查询布局的原理步骤和心得
rem即是以html文件中font-size的大小的倍数rem布局的原理:通过媒体查询设置不同屏幕宽度下的html的font-size大小,然后在css布局时用rem单位取代px,从而实现页面元素大小 ...
- 分享《机器学习实战基于Scikit-Learn和TensorFlow》中英文PDF源代码+《深度学习之TensorFlow入门原理与进阶实战》PDF+源代码
下载:https://pan.baidu.com/s/1qKaDd9PSUUGbBQNB3tkDzw <机器学习实战:基于Scikit-Learn和TensorFlow>高清中文版PDF+ ...
随机推荐
- 构建前端第13篇之---VUE的method:{}的括号未括到方法导致 _vm.linkProps is not a function
- 创建函数function
1.创建普通函数 function 函数名称(){ 函数体://封装的代码 } 函数名称()://调用函数 function getSum(){ for(var i=1,sum=0;i<=100 ...
- Java中Arrays数组的定义与使用
初始化 Java中数组是固定长度,数组变量是个对象. NullPointerException 空指针异常. ArrayIndexOutOfBoundsException 索引值越界. 数组三种初始化 ...
- MIPS Pwn赛题学习
MIPS Pwn writeup Mplogin 静态分析 mips pwn入门题. mips pwn查找gadget使用IDA mipsrop这个插件,兼容IDA 6.x和IDA 7.x,在ID ...
- 网络图片下载,commons IO包
网络图片下载,commons IO包 导入commons-io包 commons-io包下载路径: http://www.apache.org/ 项目下新建lib包,将lib包添加为库 实现多线程下载 ...
- shell脚本——awk
目录 一.awk 1.1.awk简介 1.2.基本格式 1.3.工作原理 1.4.常见的内建变量(可直接用) 按字段输出文本 1.5.awk和getline 有重定向符 无重定向符 1.6.指定分隔符 ...
- antd+vue3实现动态表单的自动校验
由于vue3用的人还不多,所以有些问题博主踩了坑只能自己爬出来了,特此做个记录.如有错误,请大家指正. 回归正题,我所做的业务是,动态添加表单项,对每一项单独做校验,效果如下: 主要代码如下: 1 & ...
- 关于knn算法的总结思考
更多的关于k近邻算法的思考 KNN(K- Nearest Neighbor)法即K最邻近法,数据挖掘分类技术中最简单的方法之一 对k近邻算法的总结: 优点部分 其可以解决分类问题,同时可以天然的解决多 ...
- systemd.service — 服务单元配置
转载:http://www.jinbuguo.com/systemd/systemd.service.html 名称 systemd.service - 服务单元配置 大纲 service.servi ...
- 初探 Python Flask+Jinja2 SSTI
初探 Python Flask+Jinja2 SSTI 文章首发安全客:https://www.anquanke.com/post/id/226900 SSTI简介 SSTI主要是因为某些语言的框架中 ...