tensorflow的数据集可以说是非常重要的部分,我认为人工智能就是数据加算法,数据没处理好哪来的算法?

对此tensorflow有一个专门管理数据集的方式tfrecord·在训练数据时提取图片与标签就更加方便,但是tensorflow

的使用可以说,有时还是会踩着坑的,对此我做了一个代码专门用于去制作tfrecord和读取tfrecord。

1.首先我们要整理数据集格式如下

是的就是这样每个类别的图片数据分别在一个文件夹图片的名字可以随意取,当然要都是相同的编码格式jpg,png之类。

我们在为这些图片按照这样的格式分好类了之拷贝整个路径就可以了

import os
import tensorflow as tf
import cv2 as cv
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' FLAGS = tf.app.flags.FLAGS
tf.flags.DEFINE_list('image_CNN_shape', [None, 40, 32, 1], 'image shape [high, width, pip]')
tf.flags.DEFINE_list('image_shape', [40, 32, 1], 'image shape [high, width, pip]')
tf.flags.DEFINE_list('label_CNN_shape', [None, 6], 'label shape is one-hot list [batch, sort]')
tf.flags.DEFINE_list('label_shape', [1], 'label shape ')
tf.flags.DEFINE_integer('batch_size', 20, 'one batch size ') def Reader(train_path):
'''
输入训练集的整个文件夹生成一个tf的训练文件
train_path
dir_name: 0开始是排序
file_name :1开始排序
:param train_path: 训练集路径
:return:
'''
# 1.生成图片文件队列
# 1.1生成分类的dir 列表
one_list = os.listdir(train_path) # 1.2路径添加完整
# list_dir = add_path(one_list, train_path)
# print(list_dir)
for i in range(len(one_list)):
one_list[i] = train_path + r'/' + str(i) all_image_list = []
all_label_list = []
# print(one_list)
for j in range(len(one_list)):
two_list = os.listdir(one_list[j])
for i in range(len(two_list)):
all_label_list.append(j)
all_image_list.append(one_list[j] + '/' + two_list[i]) print(len(all_label_list))
image_queue = tf.train.string_input_producer(all_image_list, shuffle=True)
# 2.构造阅读器
reader = tf.WholeFileReader()
# 3.读取图片
key, value = reader.read(image_queue)
# print(value)
# 4.解码数据
image = tf.image.decode_bmp(value)
image.set_shape([40, 32, 1]) # [高,宽,通道]
# print(image) # 5.批处理数据 Op_batch = tf.train.batch([image, key], batch_size=1254, num_threads=1) with tf.Session() as sess:
coor = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess) # 开启队列的线程 image_data, label_data = sess.run(Op_batch)
label_list = []
lenth = len(label_data)
for i in range(lenth):
datalist = str(label_data[i]).split('/')
label_list.append(int(datalist[1]))
write_to_tfrecord(label_list, image_data, lenth) print('tfrecord write down')
coor.request_stop() # 发出所有线程终止信号
coor.join() # 等待所有的子线程加入主线程 def add_path(listdir, train_path):
for i in range(len(listdir)):
listdir[i] = train_path + r'/' + listdir[i]
return listdir def write_to_tfrecord(label_batch, image_batch, lenth):
'''
要点:避免在循环里面eval或者run :param label_batch: numpy类型
:param image_batch: numpy类型
:param lenth: int类型batch的长度
:return: None 会生成一个文件
'''
writer = tf.python_io.TFRecordWriter(path=r"./text.tfrecords")
label_batch = tf.constant(label_batch)
label_batch = tf.cast(label_batch, tf.uint8)
for i in range(lenth):
image = image_batch[i].tostring() label = label_batch[i].eval().tostring() # 构造协议块
# tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。
Example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))
writer.write(Example.SerializeToString())
print('write: ', i)
writer.close() def read_tfrecord(path_list):
# 生成阅读器
reader = tf.TFRecordReader()
# 定义输入部分
file_queue = tf.train.string_input_producer(path_list, shuffle=False)
key, value = reader.read(file_queue)
# 解析value
features = tf.parse_single_example(serialized=value, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.string)
}) image = tf.decode_raw(features['image'], tf.uint8)
label = tf.decode_raw(features['label'], tf.uint8)
image = tf.reshape(image, FLAGS.image_shape)
label = tf.reshape(label, FLAGS.label_shape) image_batch, label_batch = tf.train.batch([image, label], batch_size=FLAGS.batch_size, num_threads=1, capacity=FLAGS.batch_size)
print(image_batch, label_batch)
return image_batch, label_batch if __name__ == '__main__': train_path = r'E:\百度云下载\tf_car_license_dataset\train_images\training-set\chinese-characters'
dir_list = []
read_path_list = [r"./other.tfrecords", ]
Reader(train_path)

Reader就是制作tfrecord

read_tfrecord就是按照路径去读取数据读出来的数据的shape 是FLAGS.image_CNN_shape形状的数据,方便做卷积

注:在做数据集整理的时候我做了许多的尝试,由于这样对图片分类,制作数据的时候打标签才更容易,最容易的莫过于,制作的时候所有一类的都放在一起,

也就是前200个读取出来的都是0号,下一个读取出来的都是1号。。。结果这样的数据集卷积神经网络怎么都不收敛,很尬,我程序跑了一天了,准确率上不去,

我都以为是我模型构建错误的原因,结果还是找不出问题所在。后来我改变了数据集的制作方式,改成乱序制作,训练就非常高效的成功了。最后要补充的是,

当数据的准确率一直在震荡,那么你可以尝试着把学习率改的更小比如0.0001就好了。这个过程还是要多多实际操作。

2.制作tfrecord慢的原因,一定要记住在tensorflow里面的tensor和op的区别,run 和 eval tensor 会获得里面的数据,但是run 和 eval op则会执行这个op,

虽然都会出现函数的返回值一样的结果是因为op运行的结果出来了,如果在制作tfrecord的for循环里面存在eval或者run op会导致制作的过程异常的慢,几千个数据集可能要做一晚上。

举个反面例子

def Reader(train_path):
'''
输入训练集的整个文件夹生成一个tf的训练文件
train_path
dir_name: 0开始是排序
file_name :1开始排序
:param train_path: 训练集路径
:return:
'''
# 1.生成图片文件队列
# 1.1生成分类的dir 列表
one_list = os.listdir(train_path) # 1.2路径添加完整
# list_dir = add_path(one_list, train_path)
# print(list_dir)
for i in range(len(one_list)):
one_list[i] = train_path + r'/' + str(i) all_image_list = []
all_label_list = []
print(one_list)
for j in range(len(one_list)):
two_list = os.listdir(one_list[j])
for i in range(len(two_list)):
all_label_list.append(j)
all_image_list.append(one_list[j] + '/' + two_list[i]) print('%s:'%j,len(two_list)) # 校验 print(all_label_list)
lenth = len(all_label_list)
lenth_image = len(all_image_list)
print('label len:', lenth)
print('image len: ', lenth_image)
image_queue = tf.train.string_input_producer(all_image_list, shuffle=False)
# 2.构造阅读器
reader = tf.WholeFileReader()
# 3.读取图片
key, value = reader.read(image_queue)
# print(value)
# 4.解码数据
image = tf.image.decode_bmp(value)
image.set_shape([40, 32, 1]) # [高,宽,通道]
# print(image) # 5.批处理数据 image_batch_op = tf.train.batch([image], batch_size=lenth, num_threads=1) with tf.Session() as sess:
coor = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess) # 开启队列的线程
write_op = write_to_tfrecord(all_label_list, image_batch_op, lenth)
print('tfrecord write down')
coor.request_stop() # 发出所有线程终止信号
coor.join() # 等待所有的子线程加入主线程 def write_to_tfrecord(label_batch, image_batch, lenth):
writer = tf.python_io.TFRecordWriter(path=r"./mnist_data/other1.tfrecords")
label_batch = tf.constant(label_batch)
label_batch = tf.cast(label_batch, tf.uint8)
for i in range(lenth):
image = image_batch[i].eval().tostring() # 在这里eval()的话就会很慢 类似于每一次都run了一下image_batch的这个op--也算是个反面教材吧
label = label_batch[i].eval().tostring() # 构造协议块
# tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。
Example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}))
writer.write(Example.SerializeToString())
print('write: ', i)
writer.close()

这里传入写入函数的image_batch的是一个op 所以在函数里面需要每一个都eval,导致程序很慢。因为每一次eval和run一个op需要牵扯到很多的数据计算。最好在循环外面就完成这个操作。

tensorflow的tfrecord操作代码与数据协议规范的更多相关文章

  1. tensorflow制作tfrecord格式数据

    tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...

  2. 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...

  3. Tensorflow之TFRecord的原理和使用心得

    本文首发于微信公众号「对白的算法屋」 大家好,我是对白. 目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上.Hive作为构建在H ...

  4. 【开源】OSharp3.0框架解说系列(6.2):操作日志与数据日志

    OSharp是什么? OSharp是个快速开发框架,但不是一个大而全的包罗万象的框架,严格的说,OSharp中什么都没有实现.与其他大而全的框架最大的不同点,就是OSharp只做抽象封装,不做实现.依 ...

  5. Scala 深入浅出实战经典 第39讲:ListBuffer、ArrayBuffer、Queue、Stack操作代码实战

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-64讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  6. PHP操作二进制字节数据

    在PHP开发中大都是操作字符类数据,极为方便,但操作二进制又如何呢,下面代码举例看看. 函数:  pack(format,args+) pack()和unpack()函数的第一个参数表如下 Bash ...

  7. 使用JWPL (Java Wikipedia Library)操作维基百科数据

    使用JWPL (Java Wikipedia Library)操作维基百科数据 1. JWPL介绍 JWPL(Java Wikipedia Library)是一个开源的访问wikipeida数据的Ja ...

  8. tensorflow学习笔记(10) mnist格式数据转换为TFrecords

    本程序 (1)mnist的图片转换成TFrecords格式 (2) 读取TFrecords格式 # coding:utf-8 # 将MNIST输入数据转化为TFRecord的格式 # http://b ...

  9. ch6-定制数据对象(打包代码和数据)

    为了看出数据属于哪个选手,教练向各个选手的数据文件中添加了标识数据:选手全名,出生日期,计时数据. 例如:sarah文件的数据更新为: Sarah Sweeney,2002-6-17,2:58,2.5 ...

随机推荐

  1. 一文看懂Vue3.0的优化

    1.源码优化: a.使用monorepo来管理源码 Vue.js 2.x 的源码托管在 src 目录,然后依据功能拆分出了 compiler(模板编译的相关代码).core(与平台无关的通用运行时代码 ...

  2. MeteoInfo脚本示例:读取FY3A AOD HDF文件

    FY3A卫星有AOD产品数据,HDF格式,这里示例用MeteoInfo脚本程序读取和显示该类数据. 脚本程序如下: #----------------------------------------- ...

  3. 用-pthread替代-lpthread

    -pthread 在多数系统中,-pthread会被展开为"-D_REENTRANT -lpthread".作为编译参数可以通知系统函数开启多线程安全特性,比如将errno定义线程 ...

  4. selenium 图片懒加载

    from selenium import webdriver options = webdriver.ChromeOptions() prefs = {} prefs['profile.managed ...

  5. 彩贝网app破解登入参数(涉及app脱壳,反编译java层,so层动态注册,反编译so层)

    一.涉及知识点 app脱壳 java层 so层动态注册 二.抓包信息 POST /user/login.html HTTP/1.1 x-app-session: 1603177116420 x-app ...

  6. Linux-京西百花山

    百花山有三个收票的入口,分别在门头沟(G109).房山(G108)和河北 108有两个方向上百花山,史家营和四马台.只有史家营方向能开车到山顶. 四马台那边,不住,要坐景区车才行 尽头是1900多米的 ...

  7. Docker学习笔记之-通过Xshell连接 CentOS服务

    上一节演示如何在虚拟机中安装 CentOS服务,Docker学习笔记之-在虚拟机VM上安装CentOS 7.8 本节主要演示如何通过 Xshell软件链接CentOS服务,本例以虚拟机作为演示,直接在 ...

  8. 国内首个 .NET 5 框架 Fur 斩获 1000 stars,1.0.0-rc.final.20 发布

          Fur 是 .NET 5 平台下企业应用开发最佳实践框架. 通往牛逼的路上,风景差得让人只想说脏话,但我在意的是远方. 啥环境 早在 1998 年微软公司对外发布 .NET/C# 平台的那 ...

  9. java实现单链表、栈、队列三种数据结构

    一.单链表 1.在我们数据结构中,单链表非常重要.它里面的数据元素是以结点为单位,每个结点是由数据元素的数据和下一个结点的地址组成,在java集合框架里面 LinkedList.HashMap(数组加 ...

  10. Nginx 配置请求响应时间

    1.常见默认nginx.conf配置日志格式 log_format main '$remote_addr - $remote_user [$time_local] "$request&quo ...