首先是生成tfrecords格式的数据,具体代码如下:

  1. #coding:utf-8
  2.  
  3. import os
  4. import tensorflow as tf
  5. from PIL import Image
  6.  
  7. cwd = os.getcwd()
  8.  
  9. '''
  10. 此处我加载的数据目录如下:
  11. bt -- 14018.jpg
  12. 14019.jpg
  13. 14020.jpg
  14.  
  15. nbt -- 1_ddd.jpg
  16. 1_dsdfs.jpg
  17. 1_dfd.jpg
  18.  
  19. 这里的bt nbt 就是类别,也就是代码中的classes
  20. '''
  21.  
  22. writer = tf.python_io.TFRecordWriter("train.tfrecords")
  23. classes = ['bt','nbt']
  24. for index, name in enumerate(classes):
  25. class_path = cwd + '/'+ name +'/' #每一类图片的目录地址
  26. for img_name in os.listdir(class_path):
  27. img_path = class_path + img_name #每一张图片的路径
  28. img = Image.open(img_path)
  29. img = img.resize((224,224))
  30. img_raw = img.tobytes() #将图片转化为原生bytes
  31. example = tf.train.Example(features = tf.train.Features(feature={
  32. 'label':tf.train.Feature(int64_list = tf.train.Int64List(value=[index])),
  33. 'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
  34. }))
  35. print "write" + ' ' + str(img_path) + "to train.tfrecords."
  36. writer.write(example.SerializeToString()) #序列化为字符串
  37. writer.close()

然后读取生成的tfrecords数据,并且将tfrecords里面的数据保存成jpg格式的图片。具体代码如下:

  1. #coding:utf-8
  2. import os
  3. import tensorflow as tf
  4. from PIL import Image
  5. cwd = '/media/project/tfLearnning/dataread/pic/'
  6. def read_and_decode(filename):
  7. #根据文件名生成一个队列
  8. filename_queue = tf.train.string_input_producer([filename])
  9.  
  10. reader = tf.TFRecordReader()
  11. _, serialized_example = reader.read(filename_queue) #返回文件名和文件
  12.  
  13. features = tf.parse_single_example(serialized_example,
  14. features={
  15. 'label':tf.FixedLenFeature([],tf.int64),
  16. 'img_raw':tf.FixedLenFeature([],tf.string),
  17. })
  18. img = tf.decode_raw(features['img_raw'],tf.uint8)
  19. img = tf.reshape(img,[224,224,3])
  20. #img = tf.cast(img,tf.float32) * (1./255) - 0.5 # 将图片变成tensor
  21. #对图片进行归一化操作将【0,255】之间的像素归一化到【-0.5,0.5】,标准化处理可以使得不同的特征具有相同的尺度(Scale)。
  22. #这样,在使用梯度下降法学习参数的时候,不同特征对参数的影响程度就一样了
  23. label = tf.cast(features['label'], tf.int32) #将标签转化tensor
  24. print img
  25. print label
  26. return img, label
  27.  
  28. #read_and_decode('train.tfrecords')
  29. img, label = read_and_decode('train.tfrecords')
  30. #print img.shape, label
  31. img_batch, label_batch = tf.train.shuffle_batch([img,label],batch_size=10,capacity=2000,min_after_dequeue=1000) #形成一个batch的数据,由于使用shuffle,因此每次取batch的时候
  32. #都是随机取的,可以使样本尽可能被充分地训练,保证min_after值小于capacit值
  33.  
  34. init = tf.global_variables_initializer()
  35.  
  36. with tf.Session() as sess:
  37. sess.run(init)
  38. # 创建一个协调器,管理线程
  39. coord = tf.train.Coordinator()
  40. # 启动QueueRunner, 此时文件名队列已经进队
  41. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  42. for i in range(10):
  43. example, l = sess.run([img, label]) #从对列中一张一张读取图片和标签
  44. #example, l = sess.run([img_batch,label_batch])
  45. print(example.shape,l)
  46.  
  47. img1=Image.fromarray(example, 'RGB') #将tensor转化成图片格式
  48. img1.save(cwd+str(i)+'_'+'Label_'+str(l)+'.jpg')#save image
  49. # 通知其他线程关闭
  50. coord.request_stop()
  51. # 其他所有线程关闭之后,这一函数才能返回
  52. coord.join(threads)

Tensorflow学习教程------tfrecords数据格式生成与读取的更多相关文章

  1. Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例

    紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...

  2. Tensorflow学习教程------过拟合

    Tensorflow学习教程------过拟合   回归:过拟合情况 / 分类过拟合 防止过拟合的方法有三种: 1 增加数据集 2 添加正则项 3 Dropout,意思就是训练的时候隐层神经元每次随机 ...

  3. Tensorflow学习教程------代价函数

    Tensorflow学习教程------代价函数   二次代价函数(quadratic cost): 其中,C表示代价函数,x表示样本,y表示实际值,a表示输出值,n表示样本的总数.为简单起见,使用一 ...

  4. tensorflow 学习教程

    tensorflow 学习手册 tensorflow 学习手册1:https://cloud.tencent.com/developer/section/1475687 tensorflow 学习手册 ...

  5. Tensorflow学习笔记----模型的保存和读取(4)

    一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...

  6. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  7. Tensorflow学习教程------创建图启动图

    Tensorflow作为目前最热门的机器学习框架之一,受到了工业界和学界的热门追捧.以下几章教程将记录本人学习tensorflow的一些过程. 在tensorflow这个框架里,可以讲是若数据类型,也 ...

  8. Tensorflow学习教程------非线性回归

    自己搭建神经网络求解非线性回归系数 代码 #coding:utf-8 import tensorflow as tf import numpy as np import matplotlib.pypl ...

  9. Tensorflow学习教程------利用卷积神经网络对mnist数据集进行分类_利用训练好的模型进行分类

    #coding:utf-8 import tensorflow as tf from PIL import Image,ImageFilter from tensorflow.examples.tut ...

随机推荐

  1. VM ESXi虚拟化使用学习笔记

    由于疫情原因,没有条件介绍安装部分的内容,也没有安装部分内容的相关截图,所以安装部分可以选择网上资料.但是只要熟练安装CentOS系统的,基本安装ESXi一看就会,设置主机地址方面有一定图形化界面,比 ...

  2. 吴裕雄 Bootstrap 前端框架开发——Bootstrap 字体图标(Glyphicons):glyphicon glyphicon-align-justify

    <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <meta name ...

  3. MERGE INTO:存在就更新不存在就新增——oracle

    MERGE INTO [your table-name] [rename your table here] USING ( [write your query here] )[rename your ...

  4. 【LeetCode】最长连续序列

    [问题]给定一个未排序的整数数组,找出最长连续序列的长度. 要求算法的时间复杂度为 O(n). 示例: 输入: [, , , , , ] 输出: 解释: 最长连续序列是 [, , , ].它的长度为 ...

  5. 一天一个设计模式——Adapter适配器模式(Wrapper模式)

    一.模式说明 在现实生活中,当需要将两种设备连接起来,但是两个设备的接口规范又不一致(比如电脑上只有Type-C接口,但是你的显示器是HDMI接口),这时候就需要一个适配器,适配器一端连接电脑,一端连 ...

  6. Arduino学习——u8glib提供的字体样式

    Fonts, Capital A Height4 Pixel Height  U8glib Font FontStruct5 Pixel Height  04 Font 04 Font 04 Font ...

  7. css 的基础样式--border--padding--margin

    border 边框复合写法 border:border-width border-style border-color; border-width 边框宽度 border-style 边框样式:sol ...

  8. java中执行javascript案例

    Nashorn js engine官方文档 https://docs.oracle.com/javase/7/docs/technotes/guides/scripting/programmer_gu ...

  9. Python中的常用内置对象之map对象

    如果你了解云计算的最重要的计算框架Mapreduce,你就对Python提供的map和reduce对象有很好的理解,在大数据面前,单机计算愈加力不从心,分布式计算也就是后来的云计算的框架担当大任,它提 ...

  10. dirname() 函数返回路径中的目录部分。

    定义和用法 dirname() 函数返回路径中的目录部分. 语法 dirname(path) 参数 描述 path 必需.规定要检查的路径. 说明 path 参数是一个包含有指向一个文件的全路径的字符 ...