http://blog.csdn.net/sinat_16823063/article/details/53946549

Tensorflow创建和读取17flowers数据集

标签: tensorflow
2016-12-30 21:43 1548人阅读 评论(8) 收藏 举报
 分类:
深度学习(4) 
    近期开始学习tensorflow,看了很多视频教程以及博客,大多数前辈在介绍tensorflow的用法时都会调用官方文档里给出的数据集,但是对于我这样的小白来说,如果想训练自己的数据集,自己将图片转换成可以输入到网络中的格式确实是有难度。但如果不会做图片的预处理,迈不出这一步,今后的学习之路会越来越难走,所以今天还是硬着头皮把我这几天已经实现的部分做一个总结。主要参考了一篇博客,文章最后有链接,通过这位博主的方法我成功生成了自己的数据集。
    首先,介绍一下用到的两个库,一个是os,一个是PIL。PIL(Python Imaging Library)是 Python 中最常用的图像处理库,而Image类又是 PIL库中一个非常重要的类,通过这个类来创建实例可以有直接载入图像文件,读取处理过的图像和通过抓取的方法得到的图像这三种方法。
    我采用的数据集是17 Category Flower Dataset。17flowers是牛津大学Visual Geometry Group选取的在英国比较常见的17种花。其中每种花有80张图片,整个数据及有1360张图片,可以在官网下载。不过在后续的训练过程中遇到了过拟合的问题,稍后再解释。
    由于17-flower数据集的结构如下图所示,标签就是最外层的文件夹的名字。所以在输入标签的时候可以直接通过文件读取的方式。
 
    我们是通过TFRecords来创建数据集的,TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(label)。

  1. import os
  2. import tensorflow as tf
  3. from PIL import Image
  4. cwd = os.getcwd()
  5. classes = os.listdir(cwd+"/17flowers/jpg")
  6. writer = tf.python_io.TFRecordWriter("train.tfrecords")
  7. for index, name in enumerate(classes):
  8. class_path = cwd + "/17flowers/jpg/" + name + "/"
  9. if os.path.isdir(class_path):
  10. for img_name in os.listdir(class_path):
  11. img_path = class_path + img_name
  12. img = Image.open(img_path)
  13. img = img.resize((224, 224))
  14. img_raw = img.tobytes()              #将图片转化为原生bytes
  15. example = tf.train.Example(features=tf.train.Features(feature={
  16. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),
  17. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
  18. }))
  19. writer.write(example.SerializeToString())  #序列化为字符串
  20. writer.close()
  21. print(img_name)

我们使用tf.train.Example来定义我们要填入的数据格式,其中label即为标签,也就是最外层的文件夹名字,img_raw为易经理二进制化的图片。然后使用tf.python_io.TFRecordWriter来写入。基本的,一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List。就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。

下面测试一下,已经存好的训练集是否可用:

  1. for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
  2. example = tf.train.Example()
  3. example.ParseFromString(serialized_example)
  4. image = example.features.feature['image'].bytes_list.value
  5. label = example.features.feature['label'].int64_list.value
  6. # 可以做一些预处理之类的
  7. print image, label

可以输出值,那么现在我们创建好的数据集已经存储在了统计目录下的train.tfrecords中了。接下来任务就是通过队列(queue)来读取这个训练集中的数据。

  1. def read_and_decode(filename):

  2. #根据文件名生成一个队列

  3. filename_queue = tf.train.string_input_producer([filename])


  4. reader = tf.TFRecordReader()

  5. _, serialized_example = reader.read(filename_queue)
  6. #返回文件名和文件

  7. features = tf.parse_single_example(serialized_example,
features={

  8. 'label': tf.FixedLenFeature([], tf.int64),
                                                                    'img_raw' : tf.FixedLenFeature([], tf.string),
})


  9. img = tf.decode_raw(features['img_raw'], tf.uint8)

  10. img = tf.reshape(img, [224, 224, 3])

  11. img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

  12. label = tf.cast(features['label'], tf.int64)


  13. return img, label

其中的filename,即刚刚通过TFReader来生成的训练集。通过将其转化成string类型数据,再通过reader来读取队列中的文件,并通过features的名字,‘label’和‘img_raw’来得到对应的标签和图片数据。之后就是一系列的转码和reshape的工作了。

    准备好了这些训练集,接下来就是利用得到的label和img进行网络的训练了。
  1. img, label = read_and_decode("train.tfrecords")

  2. img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=100, capacity=2000,
 min_after_dequeue=1000)

  3. labels = tf.one_hot(label_batch,17,1,0)
  4. 
coord = tf.train.Coordinator()

  5. threads = tf.train.start_queue_runners(coord=coord,sess=sess)
  6. 

for i in range(200):

  7. batch_xs, batch_ys = sess.run([img_batch, labels])

  8. print(sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}))

  9. print("Loss:", sess.run(cross_entropy,feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}))

  10. if i % 50 == 0:

  11. print(compute_accuracy(mnist.test.images, mnist.test.labels))

  12. coord.request_stop()
  13. 
coord.join()

注意一点,由于这里使用了队列的方式来进行训练集的读取,所以异步方式,通过Coordinator让queue runner通过coordinator来启动这些线程,并在最后读取队列结束后终止线程。

    不过,在训练这个训练集的过程中不断的输出loss函数值,发现只迭代了5次就为0了,目前想到的原因可能是训练集太小,每个类只有80张图片。另一个原因可能是网络结构太深,由于使用了VGGNet,训练参数太多,容易过拟合。下次做个小规模的网络测试一下。

Tensorflow创建和读取17flowers数据集的更多相关文章

  1. 在C#下使用TensorFlow.NET训练自己的数据集

    在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...

  2. (第二章第三部分)TensorFlow框架之读取二进制数据

    系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...

  3. tensorflow之数据读取探究(2)

    tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...

  4. 【猫狗数据集】谷歌colab之使用pytorch读取自己数据集(猫狗数据集)

    之前在:https://www.cnblogs.com/xiximayou/p/12398285.html创建好了数据集,将它上传到谷歌colab 在colab上的目录如下: 在utils中的rdat ...

  5. TensorFlow从0到1之TensorFlow逻辑回归处理MNIST数据集(17)

    本节基于回归学习对 MNIST 数据集进行处理,但将添加一些 TensorBoard 总结以便更好地理解 MNIST 数据集. MNIST由https://www.tensorflow.org/get ...

  6. TensorFlow从0到1之TensorFlow csv文件读取数据(14)

    大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 TensorFlow 中读取文件,本节将重点介绍如何从 CSV ...

  7. C#无限极分类树-创建-排序-读取 用Asp.Net Core+EF实现之方法二:加入缓存机制

    在上一篇文章中我用递归方法实现了管理菜单,在上一节我也提到要考虑用缓存,也算是学习一下.Net Core的缓存机制. 关于.Net Core的缓存,官方有三种实现: 1.In Memory Cachi ...

  8. [转载]MongoDB学习 (四):创建、读取、更新、删除(CRUD)快速入门

    本文介绍数据库的4个基本操作:创建.读取.更新和删除(CRUD). 接下来的数据库操作演示,我们使用MongoDB自带简洁但功能强大的JavaScript shell,MongoDB shell是一个 ...

  9. excel2003和excel2007文件的创建和读取

    excel2003和excel2007文件的创建和读取在项目中用的很多,首先我们要了解excel的常用组件和基本操作步骤. 常用组件如下所示: HSSFWorkbook excel的文档对象 HSSF ...

随机推荐

  1. scala学习笔记(7)

    1.包 --------------------------------------- Scala中的包和Java或者C++中命名空间的目的是相同的:管理大型程序中的名称. package a{ pa ...

  2. Scala学习笔记(6)对象

    1.单例对象.Scala没有静态方法或字段,可以使用object这个语法结构来达到同样的目的.对象定义了单个实例,包含了你想要的特性. object Accounts{ def newUniqueNu ...

  3. css实现斑马线效果

    文本实现斑马线效果 <style> p { font-size: 17px; line-height: 25px; background-color: antiquewhite; back ...

  4. jQuery进阶第四天(2019 10.13)

    1 初识面向对象(面向对象是一种思维方式) 以前写的代码 var name = '莉莉'; var sex = '女'; var age = 18; var name1 = '小明'; var sex ...

  5. 移动端和pc端公用样式及rem布局

    一:移动端准备工作<meta name="viewport" content="width=device-width, initial-scale=1.0, max ...

  6. css隐藏滚动条 兼容谷歌、火狐、IE等各个浏览器

    项目中,页面效果需要展示一个页面的移动端效果,使用的是一个苹果手机样式背景图,咋也没用过苹果,咋也不敢形容. 如下图所示: 在谷歌浏览器如图一滚动条顺利隐藏,但是火狐就如图二了,有了滚动条丑的一批. ...

  7. django基础篇04-自定义simple_tag和fitler

    自定义simple_tag app目录下创建templatetags目录 templatetags目录下创建xxpp.py 创建template对象register,注意变量名必须为register ...

  8. Insomni'hack teaser 2019 - Pwn - 1118daysober

    参考链接 https://ctftime.org/task/7459 Linux内核访问用户空间文件:get_fs()/set_fs()的使用 漏洞的patch信息 https://maltekrau ...

  9. C#基础知识之Dynamic类型

    Dynamic类型是C#4.0中引入的新类型,它允许其操作掠过编译器类型检查,而在运行时处理. 编程语言有时可以划分为静态类型化语言和动态类型化语言.C#和Java经常被认为是静态化类型的语言,而Py ...

  10. 高考数学九大超纲内容(1)wffc

    我校2016$\thicksim$2017学年度(上期)半期高三(理科)考试第12题 已知奇函数\(f(x)\)的定义域是\((-1,0)\bigcup\hspace{0.05cm}(0,1)\),\ ...