本文首发于微信公众号「对白的算法屋」

大家好,我是对白。

目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上。Hive作为构建在HDFS上的一个数据仓库,它本质上可以看作是一个翻译器,可以将HiveSQL语句翻译成MapReduce程序或Spark程序,因此模型需要的数据例如csv/libsvm文件都会保存成Hive表并存放在HDFS上,那么问题就来了,如何大规模地把HDFS中的数据直接喂到Tensorflow中呢?Tensorflow提供了一种解决方法:spark-tensorflow-connector,支持将spark DataFrame格式数据直接保存为TFRecords格式数据,接下来就带大家了解一下TFRecord的原理、构成和如何生成TFRecords文件。


TFRecord介绍

TFRecord是Tensorflow训练和推断标准的数据存储格式之一,将数据存储为二进制文件(二进制存储具有占用空间少,拷贝和读取(from disk)更加高效的特点),而且不需要单独的标签文件了,其本质是一行行字节字符串构成的样本数据。

一条TFRecord数据代表一个Example,一个Example就是一个样本数据,每个Example内部由一个字典构成,每个key对应一个Feature,key为字段名,Feature为字段名所对应的数据,Feature有三种数据类型:ByteList、FloatList,Int64List。

TFRecord构成

它实质上是由protobuf定义的一种数据协议,其中tensorflow提供了两种Example表示形式 Example和SequenceExample。它的定义代码位于[tensroflow/core/example/example.proto & feature.proto]。

Example和SequenceExample的定义:

message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
}; message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
}; // Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
}; // Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
}; message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};

我们这里以最常用的Example来进行解释。从图中可以看到,在样本生产环节,每个Example内部由一个dict构成,每个key(string)对应着一个Feature结构,这个Feature结构有三种具体形式,分别是ByteList,FloatList,Int64List三种。这三种形式便可以承载string,bytes,float,double,int,long等多种样本结构,并且基于list的表示,使得我们既可以表达scalar,也可以表达vector类型的数据(注意如果想要将一个matrix保存到到一个Feature内,其值需要时按照Row-Major拍平的1-D array, 行列数据需使用额外字段保存,方便反序列化)。这里需要注意的是,我们在序列化的时候,并未将格式信息序列化进去,实质上,序列化后的,每条tfrecord中的数据,只具有以下数据:

TFRecord中每条数据的格式:

uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data

因此我们可以看出来,TFRecord并不是一个self-describing的格式,也就是说,tfrecord的write和read都需要额外指明schema。从上图我们也能看出来,在实际训练的时候,样本都需要经过一个知晓了Schema的Parser来进行解析,然后才能传递给Tensorflow进行实际的训练。

注:这里只展示了CTR场景常使用的Example,当然也有图像等场景需要使用SequenceExample进行一些样本的结构化表达,这里不做展开。根据官方文档来看,SequenceExample主要是使用在时序特征和视频特征。其中context字段描述的是和当期时间和特征不相关的共性数据,而feature_list则持有和时间或者视频帧相关的数据。感兴趣可以参考youtube-8M这个数据集中关于样本数据的表示。

TFRecord的生成(小规模)

TFRecord的生成=Example序列化+写入TFRecord文件

构建Example时需要指定格式信息(字典)key是特征,value是BytesList/FloatList/Int64List值,但Example序列化时并未将格式信息序列化进去,因此读取TFRecord文件需要额外指明schema。

每个Example会序列化成字节字符串并写入TFRecord文件中,代码如下:

import tensorflow as tf

# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _int64list_feature(value_list):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list)) # Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
# 注意我们需要按照格式来进行数据的组装,这里的dict便按照指定Schema构造了一条Example
feature = {
'user_id': _int64_feature(user_id),
'city_id': _int64_feature(city_id),
'app_type': _int64_feature(app_type),
'viewd_pois': _int64list_feature(viewd_pois),
'avg_paid': _float_feature(avg_paid),
'comment': _bytes_feature(comment),
}
# 调用相关api将Example序列化为字节字符串
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString() # 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
with tf.python_io.TFRecordWriter(filepath) as writer:
writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
print "write demo data done." filepath = "testdata.tfrecord"
write_demo(filepath)

由以上代码可知,TFRecord的原理是:将每个样本传给serialize_example函数并输出字节字符串,再通过TFRecordWriter类写入TFRecord文件中,有多少个样本就会生成多少个字节字符串。

TFRecord的生成(大规模)

TFRecord的生成=spark DataFrame格式数据保存为tfrecords格式数据

from pyspark.sql.types import *
def main():
#从hive表中读取数据
df=spark.sql("""
select * from experiment.table""")
#tfrecords保存路径
path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
#将spark DataFrame格式数据转换为tfrecords格式数据
df.repartition(file_num).write \
.mode("overwrite") \
.format("tfrecords") \
.option("recordType", "Example")\
.save(path)
if __name__ == "__main__":
main()

TFRecord的读取

在模型训练的时候需要读取TFRecord文件,有三个步骤:

1、首先通过tf.data.TFRecordDataset() API读取TFRecord文件并创建dataset;

2、定义schema;

3、使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

Tensorflow提供了三种解析函数:

1、tf.FixedLenFeature(shape,dtype,default_value):解析定长特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

2、tf.VarLenFeature(dtype):解析变长特征,dtype:输入数据类型;

3、tf.FixedSequenceFeature(shape,dtype,default_value):解析定长序列特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

代码如下:

def read_demo(filepath):
# 定义schema
schema = {
'user_id': tf.FixedLenFeature([], tf.int64),
'city_id': tf.FixedLenFeature([], tf.int64),
'app_type': tf.FixedLenFeature([], tf.int64),
'viewed_pois': tf.VarLenFeature(tf.int64),
'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
'comment': tf.FixedLenFeature([], tf.string, default_value=''),
} # 使用相关api,按照schema解析dataset中的样本
def _parse_function(example_proto):
return tf.parse_single_example(example_proto, schema) # 读取TFRecord文件来创建dataset
dataset = tf.data.TFRecordDataset(filepath)
#按照schema解析dataset中的每个样本
parsed_dataset = dataset.map(_parse_function)
#创建Iterator并迭代Iterator即可访问dataset中的样本
next = parsed_dataset.make_one_shot_iterator().get_next() # 这里直接利用session,打印dataset中的样本
with tf.Session() as sess:
while True:
try:
print sess.run(next)
except:
print "out of data"
break

其中,

tf.parse_single_example(
serialized,
features,
name=None,
example_names=None
)

参数:

  • serialized:序列化的Example。
  • features:一个字典,key是特征,value是FixedLenFeature/VarLenFeature/FixedSequenceFeature值。
  • name:此操作的名称(可选)。
  • example_names:(可选)标量字符串张量,关联的名称。

返回:

一个字典,key是特征,value是Tensor或Sparse Tensor值。

最后欢迎大家关注我的微信公众号:对白的算法屋(duibainotes),跟踪NLP、推荐系统和对比学习等机器学习领域前沿。

想进一步交流的同学也可以通过公众号加我的微信一同探讨技术问题,谢谢。

推荐阅读

TFRecord&tf.Example

tensorflow/ecosystem

Linkedin Spark-TFRecord

Tensorflow之TFRecord的原理和使用心得的更多相关文章

  1. 4. Tensorflow的Estimator实践原理

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  2. 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  3. tensorflow核心概念和原理介绍

    关于 TensorFlow TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库. 节点(Nodes)在图中表示数学操作,图中的线(edges)则表示 ...

  4. Tensorflow 读写 tfrecord 文件(Python3)

    TensorFlow笔记博客:https://blog.csdn.net/xierhacker/article/category/6511974 写入tfrecord文件 import tensorf ...

  5. tensorflow制作tfrecord格式数据

    tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...

  6. 《转》从系统和代码实现角度解析TensorFlow的内部实现原理 | 深度

    from https://www.leiphone.com/news/201702/n0uj58iHaNpW9RJG.html?viewType=weixin 摘要 2015年11月9日,Google ...

  7. tensorflow的tfrecord操作代码与数据协议规范

    tensorflow的数据集可以说是非常重要的部分,我认为人工智能就是数据加算法,数据没处理好哪来的算法? 对此tensorflow有一个专门管理数据集的方式tfrecord·在训练数据时提取图片与标 ...

  8. 移动端开发rem布局之less+媒体查询布局的原理步骤和心得

    rem即是以html文件中font-size的大小的倍数rem布局的原理:通过媒体查询设置不同屏幕宽度下的html的font-size大小,然后在css布局时用rem单位取代px,从而实现页面元素大小 ...

  9. 分享《机器学习实战基于Scikit-Learn和TensorFlow》中英文PDF源代码+《深度学习之TensorFlow入门原理与进阶实战》PDF+源代码

    下载:https://pan.baidu.com/s/1qKaDd9PSUUGbBQNB3tkDzw <机器学习实战:基于Scikit-Learn和TensorFlow>高清中文版PDF+ ...

随机推荐

  1. video标签的视频全屏

    按钮: <div class="fullScreen" @click="fullScreen"><i class="el-icon- ...

  2. 浏览器不支持promise的finally

    IE浏览器以及edge浏览器的不支持es6里面promise的finally 解决方法: 1.npm install axios promise.prototype.finally --save 2. ...

  3. 防止因提供的sql脚本有问题导致版本bvt失败技巧

    发版本时,可能会由于测试库和开发库表结构不一样而导致数据库脚本在测试那边执行时出错,导致版本BVT失败,以下技巧可解决此问题. 步骤:备份目标库,在备份库中执行将要提供的sql脚本看有无问题,若没问题 ...

  4. js学习笔记之自调用函数、闭包、原型链

     自调用函数 var name = 'world!'; // console.log(typeof name) (function () { console.log(this.name, name, ...

  5. js学习笔记之日期倒计时DOM操作

    1.访问html元素 getElementById() 方法  返回对拥有指定 id 的第一个对象的引用,只有dom对象有效 getElementsByName() 方法  返回指定名称的对象集合 g ...

  6. 从零开始了解kubernetes

    kubernetes 已经成为容器编排领域的王者,它是基于容器的集群编排引擎,具备扩展集群.滚动升级回滚.弹性伸缩.自动治愈.服务发现等多种特性能力. 本文将带着大家快速了解 kubernetes , ...

  7. java类与对象基础篇

    java面向对象基础篇 面向对象程序设计(Object Oriented Proframming ,OOP) 面向对象的本质是:以类的方式组织代码,以对象的方式组织(封装)数据. 面向对象的核心思想是 ...

  8. 利用 cgroup 的 cpuset 控制器限制进程的 CPU 使用

    最近在做一些性能测试的事情,首要前提是控制住 CPU 的使用量.最直观的方法无疑是安装 Docker,在每个配置了参数的容器里运行基准程序. 对于计算密集型任务,在只限制 CPU 的需求下,直接用 L ...

  9. Centos配置网络和主机映射

    目录 虚拟机网络的三种配置方式 配置虚拟机IP 主机映射问题 配置虚拟机的主机名 虚拟机远程登录 虚拟机网络的三种配置方式 桥接模式:当前虚拟机与主机在同一个局域网下,同一个局域网下的所有电脑都可以访 ...

  10. MySQL-16-主从复制进阶

    延时从库 介绍 延时从库: 是我们人为配置的一种特殊从库,人为配置从库和主库延时N小时 为什么要有延时从库 数据库故障 物理损坏,普通的主从复制非常擅长解决物理损坏 逻辑损坏,普通主从复制没办法解决逻 ...