当训练数据量较小时,采用直接读取文件的方式,当训练数据量非常大时,直接读取文件的方式太耗内存,这时应采用高效的读取方法,读取tfrecords文件,这其实是一种二进制文件。tensorflow为其内置了各种存储和读取的函数,方便调用。

  不知道为啥,从tfrecords中读取数据用于训练时,收敛得更快,更平稳。上面两个图是使用tfrecords的准确率和loss值变化,下面是直接读取文件的准确率和loss值变化。

1 生成记录样本的记录文件

 root_dir = os.getcwd()

 def getTrianList():
with open("train.txt","w") as f:
for file in os.listdir(root_dir+'\\dataSet'):
for picFile in os.listdir(root_dir+"\\dataSet\\"+file):
f.write("dataSet/"+file+"/"+picFile+" "+file+"\n")
print(picFile)
if __name__=="__main__":
getTrianList()

  将样本文件路径和标签统一记录到一个txt中,后面生成tfrecords文件就是通过读取这些信息。

  

  注意文件路径和标签之间采用空格,不要使用制表符。

2 读取txt存于数组中

 def load_file(example_list_file):
lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[('col1', 'S120'), ('col2', 'i8')])
examples = []
labels = []
for example,label in lines:
examples.append(example)
labels.append(label)
#convert to numpy array
return np.asarray(examples),np.asarray(labels),len(lines)

  这段代码主要用来读取第1步生成的txt,将文件路径和标签存于数组中

3 读取图片

 def extract_image(filename,height,width):
print(filename)
image = cv2.imread(filename)
image = cv2.resize(image,(height,width))
b,g,r = cv2.split(image)
rgb_image = cv2.merge([r,g,b])
return rgb_image

  使用cv2读取图片文件

4 转化为tfrecords文件

 def trans2tfRecord(trainFile,name,output_dir,height,width):
if not os.path.exists(output_dir) or os.path.isfile(output_dir):
os.makedirs(output_dir)
_examples,_labels,examples_num = load_file(train_file)
filename = name + '.tfrecords'
writer = tf.python_io.TFRecordWriter(filename)
for i,[example,label] in enumerate(zip(_examples,_labels)):
print("NO{}".format(i))
#need to convert the example(bytes) to utf-8
example = example.decode("UTF-8")
image = extract_image(example,height,width)
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw':_bytes_feature(image_raw),
'height':_int64_feature(image.shape[0]),
'width': _int64_feature(32),
'depth': _int64_feature(32),
'label': _int64_feature(label)
}))
writer.write(example.SerializeToString())
writer.close()
 def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

5 从tfrecords中读取训练数据

 def read_tfRecord(file_tfRecord):
queue = tf.train.string_input_producer([file_tfRecord])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'height': tf.FixedLenFeature([], tf.int64),
'width':tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
}
)
image = tf.decode_raw(features['image_raw'],tf.uint8)
#height = tf.cast(features['height'], tf.int64)
#width = tf.cast(features['width'], tf.int64)
image = tf.reshape(image,[32,32,3])
image = tf.cast(image, tf.float32)
image = tf.image.per_image_standardization(image)
label = tf.cast(features['label'], tf.int64)
print(image,label)
return image,label

  从tfrecords文件中读取image和label,训练的时候,直接使用tf.train.batch函数生成用于训练的batch即可。

 image_batches,label_batches = tf.train.batch([image, label], batch_size=16, capacity=20)

  其余的部分跟之前的训练步骤一样。

tensorflowxun训练自己的数据集之从tfrecords读取数据的更多相关文章

  1. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  2. Fast RCNN 训练自己的数据集(3训练和检测)

    转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https://github.com/YihangLou/fas ...

  3. 【faster-rcnn】训练自己的数据集时的坑

    既然faster-rcnn原版发表时候是matlab版代码,那就用matlab版代码吧!不过遇到的坑挺多的,不知道python版会不会好一点. ======= update ========= 总体上 ...

  4. 【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控

    [写在前面] 用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inceptio ...

  5. 目标检测算法SSD之训练自己的数据集

    目标检测算法SSD之训练自己的数据集 prerequesties 预备知识/前提条件 下载和配置了最新SSD代码 git clone https://github.com/weiliu89/caffe ...

  6. 可变卷积Deforable ConvNet 迁移训练自己的数据集 MXNet框架 GPU版

    [引言] 最近在用可变卷积的rfcn 模型迁移训练自己的数据集, MSRA官方使用的MXNet框架 环境搭建及配置:http://www.cnblogs.com/andre-ma/p/8867031. ...

  7. caffe训练自己的数据集

    默认caffe已经编译好了,并且编译好了pycaffe 1 数据准备 首先准备训练和测试数据集,这里准备两类数据,分别放在文件夹0和文件夹1中(之所以使用0和1命名数据类别,是因为方便标注数据类别,直 ...

  8. 使用yolo3模型训练自己的数据集

    使用yolo3模型训练自己的数据集 本项目地址:https://github.com/Cw-zero/Retrain-yolo3 一.运行环境 1. Ubuntu16.04. 2. TensorFlo ...

  9. 【tf.keras】在 cifar 上训练 AlexNet,数据集过大导致 OOM

    cifar-10 每张图片的大小为 32×32,而 AlexNet 要求图片的输入是 224×224(也有说 227×227 的,这是 224×224 的图片进行大小为 2 的 zero paddin ...

随机推荐

  1. CStringArray序列化处理

    开发中需要对CStringArray进行保存操作,涉及到序列化,特总结一下: //写 CStringArray saTmp1; CStringArray saTmp2 saTmp1.AddString ...

  2. ios 如何对UITableView中的内容进行排序

    - (UITableViewCellEditingStyle)tableView:(UITableView *)tableView editingStyleForRowAtIndexPath:(NSI ...

  3. CH5402 选课【树形DP】【背包】

    5402 选课 0x50「动态规划」例题 描述 学校实行学分制.每门的必修课都有固定的学分,同时还必须获得相应的选修课程学分.学校开设了 N(N≤300) 门的选修课程,每个学生可选课程的数量 M 是 ...

  4. 【MongoDB】从入门到精通mongdb系列学习宝典,想学mongodb小伙伴请进来

    最近一段时间在学习MongoDB,在学习过程中总共编写了四十余篇博客.从mongodb软件下载到分片集群的搭建. 从理论讲解到实例练习.现在把所有博客的内容做个简单目录,方便阅读的小伙伴查询. 一. ...

  5. JQueryiframe页面操作父页面中的元素与方法(实例讲解)

    1)在iframe中查找父页面元素的方法:$('#id', window.parent.document) 2)在iframe中调用父页面中定义的方法和变量:parent.methodparent.v ...

  6. 二项分布。计算binomial(100,50,0.25)将会产生的递归调用次数(算法第四版1.1.27)

    算法第四版35页问题1.1.27,估计用一下代码计算binomial(100,50,0.25)将会产生的递归调用次数: public static double binomial(int n,int ...

  7. 3.2 - FTP文件上传下载

    题目:开发一个支持多用户同时在线的FTP程序要求:1.用户加密认证2.允许同时多用户登录3.每个用户有自己的家目录,且只能访问自己的家目录4.对用户进行磁盘配额,每个用户的可用空间不同5.允许用户在f ...

  8. HBase简单API

    一.使用IDEA的maven工程,工程结构如下: 二.maven的依赖pom.xml文件 <?xml version="1.0" encoding="UTF-8&q ...

  9. 棋盘格 测量 相机近似精度 (像素精度&物理精度)

    像素精度计算 像素精度——一像素对应多少毫米——距离不同像素精度也不同 将棋盘格与相机CCD平面大致平行摆放,通过[每个点处的近似像素精度=相邻两个角点之间的实际距离(棋盘格尺寸已知)/ 棋盘格上检出 ...

  10. Web项目管理工具精选(下)

    原文:Web项目管理工具精选(下) 我们在上篇中已推介『代码管理.任务管理.支付工具.数据记录.Dashboard Analytics.客户支持』六个方面的工具.本文将介绍剩下七类工具. A/B测试 ...