Step 0:导入必要的库

  1. import tensorflow as tfimport os

Step 1:获取图片文件名以及对应的标签

首先是读取给定路径下所有图片的名称以及对应的标签。os.listdir(file_dir)可以列出file_dir路径下所有文件名;str.split(sep='.')将字符串str以点(.)分割。

  1. # you need to change this to your data directory
  2. train_dir = 'E:\\data\\Dog_Cat\\train\\'#Windows
  3. #train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'#linux
  4. #获取给定路径下图片名及其对应的标签
  5. def get_files(file_dir):
  6. '''
  7. Args:
  8. file_dir: file directory
  9. Returns:
  10. list of images and labels
  11. '''
  12. images=[]
  13. labels=[]
  14. for file in os.listdir(file_dir):
  15. name = file.split(sep='.')
  16. if name[0]=='cat':
  17. images.append(file_dir + file)
  18. labels.append(0)
  19. else:
  20. images.append(file_dir + file)
  21. labels.append(1)
  22. return images, labels

step3:分批次读取图片

由于图片数量太多,如果一次性将全部图片读入内存的话,可能会造成内存不够用的情况,因此需要分批次地将图片读入内存中。我们可以利用tensorflow的tf.train.slice_input_producer函数,利用队列的思想实现。

  1. def get_batch(image, label, image_W, image_H, batch_size, capacity):
  2. '''
  3. Args:
  4. image: list type
  5. label: list type
  6. image_W: image width
  7. image_H: image height
  8. batch_size: batch size
  9. capacity: the maximum elements in queue
  10. Returns:
  11. image_batch: 4D tensor [batch_size, width, height, 3], dtype=tf.float32
  12. label_batch: 1D tensor [batch_size], dtype=tf.int32
  13. '''
  14. #将python的list数据类型转换为tensorflow的数据类型
  15. #image = tf.cast(image, tf.string)
  16. #label = tf.cast(label, tf.int32)
  17.  
  18. image = tf.convert_to_tensor(image, dtype=tf.string)
  19. label = tf.convert_to_tensor(label, dtype=tf.int32)
  20.  
  21. # make an input queue 生成一个队列,shuffle=True即将图片打乱放入队列中
  22. input_queue = tf.train.slice_input_producer([image, label],shuffle=True)
  23.  
  24. label = input_queue[1] #获取label对应的队列
  25. image_contents = tf.read_file(input_queue[0])#读取图片
  26. image = tf.image.decode_jpeg(image_contents, channels=3)#解码jpg格式图片
  27.  
  28. ######################################
  29. # data argumentation should go to here
  30. ######################################
  31. #图片resize
  32. image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
  33.  
  34. # if you want to test the generated batches of images, you might want to comment the following line.
  35. # 如果想看到正常的图片,请注释掉111行(标准化)和 126行(image_batch = tf.cast(image_batch, tf.float32))
  36. # 训练时不要注释掉!
  37. #数据标准化
  38. image = tf.image.per_image_standardization(image)
  39. #Creates batches of tensors in tensors.
  40. image_batch, label_batch = tf.train.batch([image, label],
  41. batch_size= batch_size,
  42. num_threads= 2, #线程数设置
  43. capacity = capacity) #队列中最多能容纳的元素
  44.  
  45. #you can also use shuffle_batch
  46. # image_batch, label_batch = tf.train.shuffle_batch([image,label],
  47. # batch_size=BATCH_SIZE,
  48. # num_threads=64,
  49. # capacity=CAPACITY,
  50. # min_after_dequeue=CAPACITY-1)
  51.  
  52. image_batch = tf.cast(image_batch, tf.float32)
  53.  
  54. return image_batch, label_batch

首先,我们需要先把image和label转换成tensorflow的tensor相关数据类型;其次,我们需要将images和labels放入队列中,需要注意的是要设置shuffle=True将顺序打乱(默认shuffle=True)。然后通过tf.read_file和tf.image.decode_jpeg函数读取图片已经将其进行解码。接下来就是重新调整图片大小(通过crop或者pad的方式实现)和将图像归一化。最后就是利用tf.train.batch读取队列中batch_size个数的图像及其对应的标签。

测试:

  接下来就是测试上面写的代码是否正确。

  1. import matplotlib.pyplot as plt
  2.  
  3. BATCH_SIZE = 4
  4. CAPACITY = 256
  5. #图片resize后的大小
  6. IMG_W = 208
  7. IMG_H = 208
  8.  
  9. #train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
  10. train_dir = 'E:\\data\\Dog_Cat\\train\\'
  11. image_list, label_list = get_files(train_dir)
  12. image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
  13.  
  14. with tf.Session() as sess:#在会话中运行程序
  15. i = 0
  16. coord = tf.train.Coordinator()#线程协调者
  17. threads = tf.train.start_queue_runners(coord=coord)
  18.  
  19. try:
  20. # Check if stop was requested.
  21. while not coord.should_stop() and i<1:
  22.  
  23. img, label = sess.run([image_batch, label_batch])
  24. print(img[0,:,:,:])
  25. # just test one batch
  26. for j in range(BATCH_SIZE):
  27. print('label: %d' %label[j])
  28. plt.imshow(img[j,:,:,:])
  29. plt.show()
  30. i+=1
  31.  
  32. except tf.errors.OutOfRangeError:#当读取完列队中所有数据时,抛出异常
  33. print('done!')
  34. finally:
  35. #Request that the threads stop.After this is called, calls to should_stop() will return True.
  36. coord.request_stop()
  37. coord.join(threads)

首先是一些参数的设置,然后通过get_files和get_batch建立等下需要运行的Graph。由于读取图片时,涉及到队列已经多线程,因此需要tf.train.Coordinator来产生一个线程协调者,主要作用是协调线程是否终止(This class implements a simple mechanism to coordinate the termination of a set of threads.),更详细的用法可以参考下文的函数介绍和官网的说明。然后调用tf.train.start_queue_runners来启动之前定义好的Graph中所有的线程。

最后的效果:

函数注释:

1)np.hstack:

函数原型:numpy.hstack(tup)

tup可以是python中的元组(tuple)、列表(list),或者numpy中数组(array),函数作用是将tup在水平方向上(按列顺序)合并。

举例:

  1. a=[1,2,3]
  2.  
  3. b=[4,5,6]
  4.  
  5. print(np.hstack((a,b)))
  6.  
  7. 输出:[1 2 3 4 5 6 ]

2)transpose()

函数原型:numpy.transpose(aaxes=None)

作用:将输入的array转置,并返回转置后的array

举例:

  1. >>> x = np.arange(4).reshape((2,2))
  2.  
  3. >>> x
  4.  
  5. array([[0, 1],
  6.  
  7. [2, 3]])
  8.  
  9. >>> np.transpose(x)
  10.  
  11. array([[0, 2],
  12.  
  13. [1, 3]])

注:

  1. image_list = ["D:\\1.jpg","D:\\2.jpg","D:\\3.jpg"]
  2. label_list = [1,0,1]
  3.  
  4. temp = np.array([image_list, label_list])
  5. print(temp)
  6. #输出:
  7. #[['D:\\1.jpg' 'D:\\2.jpg' 'D:\\3.jpg']
  8. # ['1' '0' '1']]
  9.  
  10. temp = temp.transpose()
  11. print(temp)
  12. #输出:
  13. #[['D:\\1.jpg' '1']
  14. # ['D:\\2.jpg' '0']
  15. # ['D:\\3.jpg' '1']]
  16. np.random.shuffle(temp)
  17. print(temp)
  18.  
  19. #输出:
  20. #[['D:\\2.jpg' '0']
  21. # ['D:\\1.jpg' '1']
  22. # ['D:\\3.jpg' '1']]

3)tf.cast

  1. cast(
  2.     x,
  3.     dtype,
  4.     name=None
  5. )

将x转换为dtype数据类型的张量。

举例:

  1. x = tf.constant([1.8, 2.2], dtype=tf.float32)
  2.  
  3. tf.cast(x, tf.int32) # [1, 2], dtype=tf.int32

4)tf.train.slice_input_producer

slice_input_producer(
    tensor_list,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None
)

Produces a slice of each Tensor in tensor_list.

Implemented using a Queue -- a QueueRunner for the Queue is
added to the current Graph's QUEUE_RUNNERcollection.

Args:

  • tensor_list: A list of Tensor objects. Every Tensor in tensor_list must
    have the same size in the first dimension.
  • num_epochs: An integer (optional). If specified, slice_input_producer produces
    each slice num_epochs times before generating an OutOfRange error.
    If not specified, slice_input_producer can cycle through the
    slices an unlimited number of times.
  • shuffle: Boolean. If true, the integers are randomly shuffled within each
    epoch.
  • seed: An integer (optional). Seed used if shuffle == True.
  • capacity: An integer. Sets the queue capacity.
  • shared_name: (optional). If set, this queue will be shared under the given name
    across multiple sessions.
  • name: A name for the operations (optional).

Returns:

A list of tensors, one for each element of tensor_list. If the tensor
in tensor_list has shape [N, a, b, .., z], then the
corresponding output tensor will have shape [a, b, ..., z].

Raises:

  • ValueError: if slice_input_producer produces nothing from tensor_list.

简单说来,就是生成一个队列,该队列的容量为capacity

5)tf.read_file

作用:读取输入文件的内容并输出

6)tf.image.decode_jpeg

作用:将JPEG格式编码的图片解码成uint8数据类型的tensor。

7)tf.image.resize_image_with_crop_or_pad

  1. resize_image_with_crop_or_pad(
  2.     image,
  3.     target_height,
  4.     target_width
  5. )

将图片大小调整为target_height和target_width大小。若原图像比较大,则以中心点为裁剪。若原图像比较小,则在短边补零,使得大小为target_height和target_width。

8)tf.image.per_image_standardization

线性尺度变化,使得原图像具有零均值,单位范数( zero mean and unit norm)。

也就是计算(x - mean) / adjusted_stddev,其中mean是图像中所有像素的平均值,adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))

adjusted_stddev是图像中所有像素的标准差,max作用为防止stddev的值为0。

9)tf.train.batch

  1. batch(
  2.     tensors,
  3.     batch_size,
  4.     num_threads=1,
  5.     capacity=32,
  6.     enqueue_many=False,
  7.     shapes=None,
  8.     dynamic_pad=False,
  9.     allow_smaller_final_batch=False,
  10.     shared_name=None,
  11.     name=None
  12. )

作用:Creates batches of tensors in tensors.即从输入的tensors获取batch_size大小的数据。

该函数是利用队列实现的。因此在使用的时候需要使用QueueRunner启动队列。

10)tf.train.Coordinator()

作用:线程协调者

任意一个线程可以调用coord.request_stop()来使所有线程停止。为了达到这一目的,每个线程必须定期检查coord.should_stop()。只要coord.request_stop()一被调用,那么coord.should_stop()马上返回True。

因此,一个典型的 thread running with a coordinator如下:

  1. while not coord.should_stop():
  2.  
  3. ...do some work...

11)tf.train.start_queue_runners

作用:启动graph中所有的队列。

说明:

代码来自:https://github.com/kevin28520/My-TensorFlow-tutorials,略有修改

函数作用主要参考tensorflow官网。https://www.tensorflow.org/versions/master/api_docs/

本文中修改后的代码可以在这里下载:https://github.com/hjl240/dog_vs_cat

[Kaggle] dogs-vs-cats之制作数据集[1]的更多相关文章

  1. 从零开始制作数据集所需要的所有python脚本

    最近一直在做图片数据集,积累了很多心得.我把我所使用的python脚本全部拿出来,当然这些脚本大部分网上都有,只不过比较分散. 我已经把所有代码上传到github上,觉得写的好的话,请给我一个star ...

  2. pytorch 读数据接口 制作数据集 data.dataset

    [吐槽] 啊,代码,你这个大猪蹄子 自己写了cifar10的数据接口,跟官方接口load的数据一样, 沾沾自喜,以为自己会写数据接口了 几天之后,突然想,自己的代码为啥有点慢呢,这数据集不大啊 用了官 ...

  3. pytorch人脸识别——自己制作数据集

    这是一篇面向新手的博文:因为本人也是新手,记录一下自己在做这个项目遇到的大大小小的坑. 按照下面的例子写就好了 import torch as t from torch.utils import da ...

  4. [Machine Learning] 国外程序员整理的机器学习资源大全

    本文汇编了一些机器学习领域的框架.库以及软件(按编程语言排序). 1. C++ 1.1 计算机视觉 CCV —基于C语言/提供缓存/核心的机器视觉库,新颖的机器视觉库 OpenCV—它提供C++, C ...

  5. Machine and Deep Learning with Python

    Machine and Deep Learning with Python Education Tutorials and courses Supervised learning superstiti ...

  6. 迁移学习——使用Tensorflow和VGG16预训模型进行预测

    使用Tensorflow和VGG16预训模型进行预测 from:https://zhuanlan.zhihu.com/p/28997549   fast.ai的入门教程中使用了kaggle: dogs ...

  7. Kaggle系列1:手把手教你用tensorflow建立卷积神经网络实现猫狗图像分类

    去年研一的时候想做kaggle上的一道题目:猫狗分类,但是苦于对卷积神经网络一直没有很好的认识,现在把这篇文章的内容补上去.(部分代码参考网上的,我改变了卷积神经网络的网络结构,其实主要部分我加了一层 ...

  8. [Kaggle] dogs-vs-cats之模型训练

    上一步建立好模型之后,现在就可以训练模型了. 主要代码如下: import sys #将当期路径加入系统path中 sys.path.append("E:\\CODE\\Anaconda\\ ...

  9. caffe 图片数据的转换成lmdb和数据集均值(转)

    转自网站: http://blog.csdn.net/muyiyushan/article/details/70578077 1.准备数据 使用dog/cat数据集,在训练项目根目录下分别建立trai ...

随机推荐

  1. Webpack 引入bootstrap

    Bootstrap中是一种事实上的界面标准,标准到现在的网站大量的使用它.如果可以使用webpack引入的bootstrap.css,就可以一个npm install完成项目的依赖,而不必手工的添加到 ...

  2. 使用ADO.NET查询和操作数据库

    String和StringBuilder 语法: //声明一个空的StringBuilder对象 StingBuilder对象名称 = new   StringBuilder(); //声明一个Str ...

  3. 多目标跟踪(MOT)论文随笔-POI: Multiple Object Tracking with High Performance Detection and Appearance Feature

    网上已有很多关于MOT的文章,此系列仅为个人阅读随笔,便于初学者的共同成长.若希望详细了解,建议阅读原文. 本文是tracking by detection 方法进行多目标跟踪的文章,最大的特点是使用 ...

  4. VS2017调试器无法附加到IIS进程(w3wp.exe)

    问题描述: 当使用VS2017-> 调试->附加到进程来调试IIS进程(w3wp.exe)时,报错"无法附加到进程,已附加了一个调试器" 为了解决这个问题花了不少时间, ...

  5. C第九次博客作业--指针

    一.PTA实验作业 题目1:两个4位正整数的后两位互换 1. 本题PTA提交列 2. 设计思路 3.代码截图 本题调试过程碰到问题及PTA提交列表情况说明 刚开始想到的交换是令t=a;a=b;b=t这 ...

  6. 设计模式NO.2

    设计模式NO.2 本次博客内容为第二次设计模式的练习.根据老师的要求完成下列题目: 题目1 如果需要开发一个跨平台视频播放器,可以在不同操作系统平台(如Windows.Linux.UNIX等)上播放多 ...

  7. Archlinux安装和使用技巧

    一 准备工作 1  文件下载及启动盘制作 文件可以在https://mirrors.ustc.edu.cn/,这是个中科大的镜像网,选择如下: 下载完成后,就是制作一个启动盘,我使用的是Linux下强 ...

  8. 20145237《Java程序设计》第一周学习总结

    教材学习内容总结 java可分为Java SE.Java EE.Java ME三大平台. java SE分为JVM.JRE.JDK.与java语言四个部分. JRE包括java SE API和JVM. ...

  9. Django 基本设置

    建立django目录,为了独立区分app和主站的关系,需要把app完全和主站分离 app/views.py from django.shortcuts import render from djang ...

  10. 《高级软件测试》JIRA使用手册(一)JIRA基本情况

    JIRA 官方网站为:https://www.atlassian.com/software/jira 中文代理网站为:https://www.jira.cn 现版本:v7.3.0 Atlassian公 ...