tensorflow的数据集可以说是非常重要的部分,我认为人工智能就是数据加算法,数据没处理好哪来的算法?

对此tensorflow有一个专门管理数据集的方式tfrecord·在训练数据时提取图片与标签就更加方便,但是tensorflow

的使用可以说,有时还是会踩着坑的,对此我做了一个代码专门用于去制作tfrecord和读取tfrecord。

1.首先我们要整理数据集格式如下

是的就是这样每个类别的图片数据分别在一个文件夹图片的名字可以随意取,当然要都是相同的编码格式jpg,png之类。

我们在为这些图片按照这样的格式分好类了之拷贝整个路径就可以了

import os
import tensorflow as tf
import cv2 as cv
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' FLAGS = tf.app.flags.FLAGS
tf.flags.DEFINE_list('image_CNN_shape', [None, 40, 32, 1], 'image shape [high, width, pip]')
tf.flags.DEFINE_list('image_shape', [40, 32, 1], 'image shape [high, width, pip]')
tf.flags.DEFINE_list('label_CNN_shape', [None, 6], 'label shape is one-hot list [batch, sort]')
tf.flags.DEFINE_list('label_shape', [1], 'label shape ')
tf.flags.DEFINE_integer('batch_size', 20, 'one batch size ') def Reader(train_path):
'''
输入训练集的整个文件夹生成一个tf的训练文件
train_path
dir_name: 0开始是排序
file_name :1开始排序
:param train_path: 训练集路径
:return:
'''
# 1.生成图片文件队列
# 1.1生成分类的dir 列表
one_list = os.listdir(train_path) # 1.2路径添加完整
# list_dir = add_path(one_list, train_path)
# print(list_dir)
for i in range(len(one_list)):
one_list[i] = train_path + r'/' + str(i) all_image_list = []
all_label_list = []
# print(one_list)
for j in range(len(one_list)):
two_list = os.listdir(one_list[j])
for i in range(len(two_list)):
all_label_list.append(j)
all_image_list.append(one_list[j] + '/' + two_list[i]) print(len(all_label_list))
image_queue = tf.train.string_input_producer(all_image_list, shuffle=True)
# 2.构造阅读器
reader = tf.WholeFileReader()
# 3.读取图片
key, value = reader.read(image_queue)
# print(value)
# 4.解码数据
image = tf.image.decode_bmp(value)
image.set_shape([40, 32, 1]) # [高,宽,通道]
# print(image) # 5.批处理数据 Op_batch = tf.train.batch([image, key], batch_size=1254, num_threads=1) with tf.Session() as sess:
coor = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess) # 开启队列的线程 image_data, label_data = sess.run(Op_batch)
label_list = []
lenth = len(label_data)
for i in range(lenth):
datalist = str(label_data[i]).split('/')
label_list.append(int(datalist[1]))
write_to_tfrecord(label_list, image_data, lenth) print('tfrecord write down')
coor.request_stop() # 发出所有线程终止信号
coor.join() # 等待所有的子线程加入主线程 def add_path(listdir, train_path):
for i in range(len(listdir)):
listdir[i] = train_path + r'/' + listdir[i]
return listdir def write_to_tfrecord(label_batch, image_batch, lenth):
'''
要点:避免在循环里面eval或者run :param label_batch: numpy类型
:param image_batch: numpy类型
:param lenth: int类型batch的长度
:return: None 会生成一个文件
'''
writer = tf.python_io.TFRecordWriter(path=r"./text.tfrecords")
label_batch = tf.constant(label_batch)
label_batch = tf.cast(label_batch, tf.uint8)
for i in range(lenth):
image = image_batch[i].tostring() label = label_batch[i].eval().tostring() # 构造协议块
# tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。
Example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))
writer.write(Example.SerializeToString())
print('write: ', i)
writer.close() def read_tfrecord(path_list):
# 生成阅读器
reader = tf.TFRecordReader()
# 定义输入部分
file_queue = tf.train.string_input_producer(path_list, shuffle=False)
key, value = reader.read(file_queue)
# 解析value
features = tf.parse_single_example(serialized=value, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.string)
}) image = tf.decode_raw(features['image'], tf.uint8)
label = tf.decode_raw(features['label'], tf.uint8)
image = tf.reshape(image, FLAGS.image_shape)
label = tf.reshape(label, FLAGS.label_shape) image_batch, label_batch = tf.train.batch([image, label], batch_size=FLAGS.batch_size, num_threads=1, capacity=FLAGS.batch_size)
print(image_batch, label_batch)
return image_batch, label_batch if __name__ == '__main__': train_path = r'E:\百度云下载\tf_car_license_dataset\train_images\training-set\chinese-characters'
dir_list = []
read_path_list = [r"./other.tfrecords", ]
Reader(train_path)

Reader就是制作tfrecord

read_tfrecord就是按照路径去读取数据读出来的数据的shape 是FLAGS.image_CNN_shape形状的数据,方便做卷积

注:在做数据集整理的时候我做了许多的尝试,由于这样对图片分类,制作数据的时候打标签才更容易,最容易的莫过于,制作的时候所有一类的都放在一起,

也就是前200个读取出来的都是0号,下一个读取出来的都是1号。。。结果这样的数据集卷积神经网络怎么都不收敛,很尬,我程序跑了一天了,准确率上不去,

我都以为是我模型构建错误的原因,结果还是找不出问题所在。后来我改变了数据集的制作方式,改成乱序制作,训练就非常高效的成功了。最后要补充的是,

当数据的准确率一直在震荡,那么你可以尝试着把学习率改的更小比如0.0001就好了。这个过程还是要多多实际操作。

2.制作tfrecord慢的原因,一定要记住在tensorflow里面的tensor和op的区别,run 和 eval tensor 会获得里面的数据,但是run 和 eval op则会执行这个op,

虽然都会出现函数的返回值一样的结果是因为op运行的结果出来了,如果在制作tfrecord的for循环里面存在eval或者run op会导致制作的过程异常的慢,几千个数据集可能要做一晚上。

举个反面例子

def Reader(train_path):
'''
输入训练集的整个文件夹生成一个tf的训练文件
train_path
dir_name: 0开始是排序
file_name :1开始排序
:param train_path: 训练集路径
:return:
'''
# 1.生成图片文件队列
# 1.1生成分类的dir 列表
one_list = os.listdir(train_path) # 1.2路径添加完整
# list_dir = add_path(one_list, train_path)
# print(list_dir)
for i in range(len(one_list)):
one_list[i] = train_path + r'/' + str(i) all_image_list = []
all_label_list = []
print(one_list)
for j in range(len(one_list)):
two_list = os.listdir(one_list[j])
for i in range(len(two_list)):
all_label_list.append(j)
all_image_list.append(one_list[j] + '/' + two_list[i]) print('%s:'%j,len(two_list)) # 校验 print(all_label_list)
lenth = len(all_label_list)
lenth_image = len(all_image_list)
print('label len:', lenth)
print('image len: ', lenth_image)
image_queue = tf.train.string_input_producer(all_image_list, shuffle=False)
# 2.构造阅读器
reader = tf.WholeFileReader()
# 3.读取图片
key, value = reader.read(image_queue)
# print(value)
# 4.解码数据
image = tf.image.decode_bmp(value)
image.set_shape([40, 32, 1]) # [高,宽,通道]
# print(image) # 5.批处理数据 image_batch_op = tf.train.batch([image], batch_size=lenth, num_threads=1) with tf.Session() as sess:
coor = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess) # 开启队列的线程
write_op = write_to_tfrecord(all_label_list, image_batch_op, lenth)
print('tfrecord write down')
coor.request_stop() # 发出所有线程终止信号
coor.join() # 等待所有的子线程加入主线程 def write_to_tfrecord(label_batch, image_batch, lenth):
writer = tf.python_io.TFRecordWriter(path=r"./mnist_data/other1.tfrecords")
label_batch = tf.constant(label_batch)
label_batch = tf.cast(label_batch, tf.uint8)
for i in range(lenth):
image = image_batch[i].eval().tostring() # 在这里eval()的话就会很慢 类似于每一次都run了一下image_batch的这个op--也算是个反面教材吧
label = label_batch[i].eval().tostring() # 构造协议块
# tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。
Example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))
writer.write(Example.SerializeToString())
print('write: ', i)
writer.close()

这里传入写入函数的image_batch的是一个op 所以在函数里面需要每一个都eval,导致程序很慢。因为每一次eval和run一个op需要牵扯到很多的数据计算。最好在循环外面就完成这个操作。

tensorflow的tfrecord操作代码与数据协议规范的更多相关文章

  1. tensorflow制作tfrecord格式数据

    tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...

  2. 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  3. Tensorflow之TFRecord的原理和使用心得

    本文首发于微信公众号「对白的算法屋」 大家好,我是对白. 目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上.Hive作为构建在H ...

  4. 【开源】OSharp3.0框架解说系列(6.2):操作日志与数据日志

    OSharp是什么? OSharp是个快速开发框架,但不是一个大而全的包罗万象的框架,严格的说,OSharp中什么都没有实现.与其他大而全的框架最大的不同点,就是OSharp只做抽象封装,不做实现.依 ...

  5. Scala 深入浅出实战经典 第39讲:ListBuffer、ArrayBuffer、Queue、Stack操作代码实战

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  6. PHP操作二进制字节数据

    在PHP开发中大都是操作字符类数据,极为方便,但操作二进制又如何呢,下面代码举例看看. 函数:  pack(format,args+) pack()和unpack()函数的第一个参数表如下 Bash ...

  7. 使用JWPL (Java Wikipedia Library)操作维基百科数据

    使用JWPL (Java Wikipedia Library)操作维基百科数据 1. JWPL介绍 JWPL(Java Wikipedia Library)是一个开源的访问wikipeida数据的Ja ...

  8. tensorflow学习笔记(10) mnist格式数据转换为TFrecords

    本程序 (1)mnist的图片转换成TFrecords格式 (2) 读取TFrecords格式 # coding:utf-8 # 将MNIST输入数据转化为TFRecord的格式 # http://b ...

  9. ch6-定制数据对象(打包代码和数据)

    为了看出数据属于哪个选手,教练向各个选手的数据文件中添加了标识数据:选手全名,出生日期,计时数据. 例如:sarah文件的数据更新为: Sarah Sweeney,2002-6-17,2:58,2.5 ...

随机推荐

  1. linux内存优化之手工释放linux内存

    先介绍下free命令 Linux free命令用于显示内存状态. free指令会显示内存的使用情况,包括实体内存,虚拟的交换文件内存,共享内存区段,以及系统核心使用的缓冲区等. 语法: free [- ...

  2. MeteoInfoLab脚本示例:中文处理

    在脚本中使用中文需要指明是unicode编码,即在含有中文的字符串前加u,比如:u'中文'.还需要将字体指定为一种中文字体.详见下面的例子.脚本程序: x = [1,2,3,4] y = [1,4,9 ...

  3. 【全网免费VIP观看】哔哩哔哩番剧解锁大会员-集合了优酷-爱奇艺-腾讯-芒果-乐视-ab站等全网vip视频免费破解去广告-高清普清电视观看-持续更新

    哔哩哔哩番剧解锁大会员-集合了优酷-爱奇艺-腾讯-芒果-乐视-ab站等全网vip视频免费破解去广告-高清普清电视观看-持续更新 前言 突然想看电视,结果 没有VIP 又不想花钱,这免费的不久来啦. 示 ...

  4. mongodb安装及使用

    安装命令: sudo apt-get install mongodb 开始认证,创建用户: 编辑配置文件: sudo vim /etc/mongodb.conf 11行中的 bind_ip值 修改成为 ...

  5. 【应用服务 App Service】NodeJS +Egg 发布到App Service时遇见 [ERR_SYSTEM_ERROR]: A system error occurred:uv_os_get_passwd returned ENOENT(no such file or directory)

    问题情形 本地NodeJS应用使用Egg脚手架构建,本地运行测试完全没有问题,发布后App Service后不能运行.通过登录到kudu后(https://<your web site>. ...

  6. 使用pyenv实现python多版本共存

    背景 如果是Ubuntu等桌面系统,都已经更新到了Python较新的版本.但多数生产环境使用的还是红帽系统. CentOS7默认还是Python2.7,而开发环境如果是高版本Python就带来了问题. ...

  7. linux硬盘分区及挂载

    今天买的一台服务器发现其硬盘容量与购买界面的描述不符,于是我去问了客服才知道有一块硬盘需要自己挂载,所以记录自己硬盘分区以及挂载操作得此文. 测试环境 ​ 由于时间限制,本人仅在centos 8下测试 ...

  8. Java语言对对象采用的是引用传递还是按值传递?

    按值调用表示方法接收的是调用者提供的值:而按引用调用表示方法接收的是调用者提供的变量地址:一个方法可以修改传递引用所对应的变量值, 而不能修改传递值调用所对应的变量值: Java语言对对象采用的是引用 ...

  9. 看完这篇良心帖!你的Python入门基础就差不多了

    有段时间没跟各位粉丝分享编程资源福利了,看了下自己的资料夹,就剩下我认为比较好的Python学习资料了.相信这套资料可以对你进阶高级工程师有帮助!全民学Python的话题铺天盖地,中国的Python学 ...

  10. (静默安装)Cent OS 6_5(x86_64)下安装Oracle 11g

    Cent OS 6_5(x86_64)下安装Oracle 11g 1 硬件要求   1.1 内存 & swap 物理内存不少于1G 硬盘可以空间不少于5G swap分区空间不少于2G Mini ...