一、我们都知道Python由于GIL的原因导致多线程并不是真正意义上的多线程。但是TensorFlow在做多线程使用的时候是吧GIL锁放开了的。所以TensorFlow是真正意义上的多线程。这里我们主要是介绍queue式的多线程运行方式。

  二、了解一下TensorFlow多线程queue的使用过程

  1.   tensorflow:
  2. 多线程是真正的多线程执行
  3. 队列:
  4. tf.FIFOQueue(<capacity>, <dtypes>, <name>), 先进先出
  5. tf.RandomShuffleQueue, 随机出队列
  6. 多线程:
  7. 当数据量很大时,入队操作从硬盘中读取,放入内存。主线程需要等待操作完成,才能训练。
  8. 使用多线程,可以达到边处理,边训练的异步效果。
  9. 队列管理器(弃用):
  10. tf.train.QueueRunner(<queue>, <enqueue_ops>)
  11. enqueue_ops: 添加线程的队列操作列表[]*2为开启2个线程,[]内为操作部分
  12. method:
  13. create_threads(<sess>, <coord>, <start>):
  14. 创建线程来运行给定的入队操作。
  15. start: 布尔值,是否启动线程
  16. coord: 线程协调器
  17. return: 线程实例
  18. 线程协调器:
  19. 协调线程之间终止

  注意:这里使用的是TensorFlow1.0版本,在后续的版本中基本 把这种方式废弃了。但是这里为了好的了解文件读取的方式,我们使用queue式的多线程来执行。

  1. import tensorflow as tf
  2.  
  3. def queue_demo():
  4.  
  5. # 1、声明队列
  6. queue = tf.FIFOQueue(3, dtypes=tf.float32)
  7.  
  8. # 2、加入数据
  9. init_queue = queue.enqueue_many([[0.1, 0.2, 0.3]])
  10.  
  11. # 3、取出数据
  12. data = queue.dequeue()
  13.  
  14. # 4、处理数据
  15. en_queue = queue.enqueue(data + 1)
  16.  
  17. with tf.Session() as sess:
  18. # 初始化操作
  19. sess.run(init_queue)
  20. # 循环
  21. for i in range(10):
  22. sess.run(en_queue)
  23. for i in range(queue.size().eval()):
  24. print(queue.dequeue().eval())
  1. import tensorflow as tf
  2.  
  3. def queue_thread_demo():
  4. # 1、声明队列
  5. queue = tf.FIFOQueue(100, dtypes=tf.float32)
  6.  
  7. # 2、加入数据
  8. for i in range(100):
  9. queue.enqueue((i + 1)/100)
  10.  
  11. # 3、操作
  12. data = queue.dequeue()
  13. en_queue = queue.enqueue(data + 1)
  14.  
  15. # 3、定义队列管理器
  16. qr = tf.train.QueueRunner(queue, enqueue_ops=[en_queue] * 2)
  17.  
  18. with tf.Session() as sess:
  19. # 开启线程协调器
  20. coord = tf.train.Coordinator()
  21. # 开启线程
  22. threads = qr.create_threads(sess, coord=coord, start=True)
  23. for i in range(100):
  24. print(sess.run(queue.dequeue()))
  25. # 注:没有线程协调器,主线程结束,会结束session,导致异常。
  26. coord.request_stop()
  27. coord.join(threads)  

  三、了解基本的数据读取过程和api

  1.   文件io:
  2. 1csv文件读取一行
  3. 2、二进制文件指定bytes
  4. 3、图片文件一张一张
  5. 流程:
  6. 1、构造一个文件队列
  7. 2、读取文件内容
  8. 3、解码文件内容
  9. 4、批处理
  10. api:
  11. 1、文件队列构造
  12. tf.train.string_input_producer(<string_tensor>, <shuffle=True>)
  13. string_tensor: 含有文件名的一阶张量
  14. num_epochs: 过几遍数据,默认无数遍
  15. 2、文件阅读器
  16. tf.TextLineReadercsv文件格式类型
  17. tf.FixedLengthRecordReader(record_bytes)、读取固定值的二进制文件
  18. tf.TFRecordReader、读取TfRecords
  19. 共同:
  20. read(file_queue): 队列中指定数量
  21. return: Tensors 元组(key:文件名, value默认行内容)
  22. 3、文件解码器:
  23. tf.decode_csv(<records>, <record_defaults=None>, <field_delim=None>, <name=None>)
  24. CSV转换为张量,与tf.TextLineReader搭配使用
  25. records: tensor型字符串,每一个字符串为CSV中的记录
  26. record_defaults: 参数决定了所有张量的类型,并设置一个值在输入字符串中缺少使用默认值
  27. tf.decode_raw(<bytes>, <out_type>, <little_endian=None>, <name=None>)
  28. 将字节转换为一个向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用,二进制读取为utf-8格式 

  在读取文件之间先了解批处理的作用,主要是讲每次读出来的数据,缓存,然后到达一个批次,统一训练

  1. 管道读端批处理:
  2. tf.train.batch(<tensors>, <batch_size>, <num_threads=1>, <capacity=32>, <name=None>)
  3. tensors: 张量列表
  4. tf.train.shuffle_batch(<tensors>, <batch_size>, <capacity>, <min_dequeue>)
  5. min_dequeue: 留下队列里的张量个数,能够保持随机打乱

  四、csv文件读取

  1.     csv文件读取:
  2. 1、找到文件,构建列表
  3. 2、构造文件队列
  4. 3、构造阅读器,读取队列内容
  5. 4、解码内容
  6. 5、批处理
  1. import os
  2. import tensorflow as tf
  3.  
  4. def csv_io():
  5. # 1、找到文件,加入队列
  6. file_names = os.listdir("data/csv")
  7. file_list = [os.path.join("data/csv", file_name) for file_name in file_names]
  8. file_queue = tf.train.string_input_producer(file_list)
  9. # 2、读取一行数据
  10. reader = tf.TextLineReader()
  11. key, value = reader.read(file_queue)
  12. # 3、解码csv
  13. records = [[-1], [-1]]
  14. num1, num2 = tf.decode_csv(value, record_defaults=records)
  15. # 4、批处理
  16. num1_batch, num2_batch = tf.train.batch([num1, num2], batch_size=9, num_threads=1, capacity=9)
  17.  
  18. with tf.Session() as sess:
  19. # 加入线程协调器
  20. coord = tf.train.Coordinator()
  21. # 线程运行
  22. threads = tf.train.start_queue_runners(sess, coord=coord)
  23. print(sess.run([num1_batch, num2_batch]))
  24.  
  25. # 子线程回收
  26. coord.request_stop()
  27. coord.join(threads)

  五、图片文件读取

  1. 图片读取:
  2. 每一个样本必须保证特征数量一样
  3. 特征值:像素值
  4. 单通道:灰度值(黑白图片,像素中只有一个值)
  5. 三通道:RGB(每个像素都有3个值)
  6. 三要素:长度宽度、通道值
  7. 图像的基本操作:
  8. 目的:
  9. 1、增加图片数据的统一性
  10. 2、所有图片装换成指定大小
  11. 3、缩小图片数据量,防止增加开销
  12. 操作:
  13. 缩小图片大小
  14. api
  15. 图片缩放:
  16. tf.image.resize_images(<images>, <size>)
  17. <images>:4-D形状[batch, height, width, channels]/3-D[height, width, channels]
  18. <size>:1-D int32张量:new_height, new_width, 图像的新尺寸
  19. return4-D/3-D格式图片
  20. 图片读取api
  21. tf.WholeFileReader:
  22. 将文件的全部内容作为输入的读取器
  23. return:读取器实例
  24. read(<file_queue>):输出将一个文件名(key)和该文件的内容值
  25. 图像解码器:
  26. tf.image.decode_jpeg(<contents>):
  27. JPEG编码的图像解码为unit8张量
  28. returnuint8张量,3-D形状[height, width, channels]
  29. tf.image.decode_png():
  30. PNG编码的图像解码为uint8/uint16的张量
  31. return:张量类型,3-D[height, width, channels]
  1. import os
  2. import tensorflow as tf
  3.  
  4. def image_io():
  5. # 1、读取文件放入队列
  6. image_names = os.listdir("data/image")
  7. image_files = [os.path.join("data/image", image_name) for image_name in image_names]
  8. image_queue = tf.train.string_input_producer(image_files)
  9.  
  10. # 2、读取一张图片数据
  11. reader = tf.WholeFileReader()
  12. # value:一整张图片的数据
  13. key, value = reader.read(image_queue)
  14.  
  15. # 3、解码
  16. image = tf.image.decode_jpeg(value)
  17. print(image)
  18.  
  19. # 4、处理图片的大小
  20. new_image = tf.image.resize_images(image, [350, 350])
  21. print(new_image)
  22. # 注意一定要固定形状,批处理的时候所有数据必须固定
  23. new_image.set_shape([350, 350, 3])
  24. print(new_image)
  25.  
  26. # 5、批处理
  27. image_batch = tf.train.batch([new_image], batch_size=2, num_threads=1, capacity=2)
  28.  
  29. # 6、运行
  30. with tf.Session() as sess:
  31. # 加入线程协调器
  32. coord = tf.train.Coordinator()
  33. # 线程运行
  34. threads = tf.train.start_queue_runners(sess, coord=coord)
  35. print(sess.run([image_batch]))
  36.  
  37. # 子线程回收
  38. coord.request_stop()
  39. coord.join(threads)

  六、二进制文件读取

  1.   二进制文件读取:
  2. api:
  3. tf.FixedLengthRecordReader(<record_bytes>)
  4. record_bytes:数据长度
  5. 解码器:
  6. tf.decode_raw(<bytes>, <out_type>, <little_endian=None>, <name=None>)
  7. bytes:数据
  8. out_type:输出类型
  1. import os
  2. import tensorflow as tf
  3.  
  4. def cifar_io():
  5. # 1、读取文件加入队列
  6. cifar_names = os.listdir("data/cifar")
  7. cifar_files = [os.path.join("data/cifar", cifar_name) for cifar_name in cifar_names if cifar_name.endswith(".bin") and cifar_name != "test_batch.bin"]
  8. file_queue = tf.train.string_input_producer(cifar_files)
  9.  
  10. # 2、读取二进制文件
  11. reader = tf.FixedLengthRecordReader(record_bytes=(32 * 32 * 3 + 1))
  12. key, value = reader.read(file_queue)
  13.  
  14. # 3、解码数据(二进制数据)
  15. # 样本数据集根据具体数据处理,这里的数据为第一个数据为目标值,后面的为图片数据
  16. target_image = tf.decode_raw(value, tf.uint8)
  17.  
  18. # 4、分割数据
  19. target = tf.slice(target_image, [0], [1])
  20. image = tf.slice(target_image, [1], [32 * 32 * 3])
  21.  
  22. # 5、特征数据形状改变
  23. new_image = tf.reshape(image, [32, 32, 3])
  24. print(new_image)
  25.  
  26. # 6、批处理
  27. image_batch, target_batch = tf.train.batch([new_image, target], batch_size=10, capacity=10)
  28. print(image_batch, target_batch)
  29.  
  30. # 7、运行
  31. with tf.Session() as sess:
  32. # 线程协调器
  33. coord = tf.train.Coordinator()
  34. # 线程运行
  35. threads = tf.train.start_queue_runners(sess, coord=coord)
  36. print(sess.run([image_batch, target_batch]))
  37.  
  38. # 子线程回收
  39. coord.request_stop()
  40. coord.join(threads)

  七、上面说完了,常用文件读取的方式,下面说一下TensorFlow文件的存储与读取的方式。TensorFlow一般采用*.threcords文件格式进行保存。它是一种内置文件格式,是一种二进制文件,它可以更好的利用内存,更方便的复制和移动。

  1.   tf.TFRecordReader
  2. 一种内置文件格式,是一种二进制文件,它可以更好的利用内存,更方便的复制和移动
  3. 为了将二进制数据和标签(训练类别标签),数据存储在同一文件中
  4. 分析、存取
  5. 文件格式:*.threcords
  6. 写入文件内容:example协议块
  7. TF存储:
  8. TFRecord存储器
  9. tf.python_io.TFRecordWriter(<path>)
  10. method:
  11. write(record)
  12. close
  13. Example协议块:
  14. tf.train.Example(<features=None>)
  15. features:tf.train.Features(<feature=None>)实例
  16. feature:字典数据,key为要保存的数据
  17. tf.train.Feature(<**options>)
  18. **options:
  19. tf.train.ByteList(<value=[Bytes]>)
  20. tf.train.IntList(<value=[Value]>)
  21. tf.train.FloatList(<value=[Value]>)
  22. return:Features实例
  23. return:Example协议块
  24. TF读取:
  25. tf.parse_example(<serialized>, <features=None>, <name=None>)
  26. serialized:标量字符串Tensor,一个序列化的Example
  27. features:dict字典数据,键为读取的名字,值为FixedLenFeature
  28. return:一个键值对组成的字典,键为读取的名字
  29. tf.FixedLenFeature(<shape>, <dtype>)
  30. shape:形状
  31. dtype:数据类型(float32/int64/string
  1. import os
  2. import tensorflow as tf
  3.  
  4. def tf_records_io():
  5. # 1、读取文件加入队列
  6. cifar_names = os.listdir("data/cifar")
  7. cifar_files = [os.path.join("data/cifar", cifar_name) for cifar_name in cifar_names if
  8. cifar_name.endswith(".bin") and cifar_name != "test_batch.bin"]
  9. file_queue = tf.train.string_input_producer(cifar_files)
  10.  
  11. # 2、读取二进制文件
  12. reader = tf.FixedLengthRecordReader(record_bytes=(32 * 32 * 3 + 1))
  13. key, value = reader.read(file_queue)
  14.  
  15. # 3、解码数据(二进制数据)
  16. # 样本数据集根据具体数据处理,这里的数据为第一个数据为目标值,后面的为图片数据
  17. target_image = tf.decode_raw(value, tf.uint8)
  18.  
  19. # 4、分割数据
  20. target = tf.slice(target_image, [0], [1])
  21. image = tf.slice(target_image, [1], [32 * 32 * 3])
  22.  
  23. # 5、特征数据形状改变
  24. new_image = tf.reshape(image, [32, 32, 3])
  25. print(new_image)
  26.  
  27. # 6、批处理
  28. image_batch, target_batch = tf.train.batch([new_image, target], batch_size=10, capacity=10)
  29. print(image_batch, target_batch)
  30.  
  31. # 7、tf文件写入
  32. with tf.Session() as sess:
  33. if not os.path.exists("data/tf_records/cifar.tfrecords"):
  34. # 1)存进tfRecords文件
  35. print("开始存储")
  36. with tf.python_io.TFRecordWriter(path="data/tf_records/cifar.tfrecords") as writer:
  37. # 2)循环次数为批次数
  38. for i in range(10):
  39. # 获取对应值
  40. image_data = image_batch[i].eval().tostring()
  41. target_data = int(target_batch[i].eval()[0])
  42. # 3)产生实例
  43. example = tf.train.Example(features=tf.train.Features(feature={
  44. "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
  45. "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[target_data]))
  46. }))
  47. # 4)写入数据
  48. writer.write(example.SerializeToString())
  49. print("结束存储")
  50.  
  51. # 8、tf文件读取
  52. # 1)读取tfRecords文件
  53. tf_queue = tf.train.string_input_producer(["data/tf_records/cifar.tfrecords"])
  54.  
  55. # 2)读取数据
  56. tf_reader = tf.TFRecordReader()
  57. key, value = tf_reader.read(tf_queue)
  58.  
  59. # 3)解析example
  60. features = tf.parse_single_example(value, features={
  61. "image": tf.FixedLenFeature([], dtype=tf.string),
  62. "target": tf.FixedLenFeature([], dtype=tf.int64)
  63. })
  64. print(features["image"], features["target"])
  65.  
  66. # 4)解码数据
  67. image = tf.decode_raw(features["image"], tf.uint8)
  68. image_reshape = tf.reshape(image, [32, 32, 3])
  69. target = tf.cast(features["target"], tf.int32)
  70. print(image_reshape, target)
  71. # 5)批处理
  72. image_batch, target_batch = tf.train.batch([image_reshape, target], batch_size=10, capacity=10)
  73.  
  74. # 9、运行
  75. with tf.Session() as sess:
  76. # 线程协调器
  77. coord = tf.train.Coordinator()
  78. # 线程运行
  79. threads = tf.train.start_queue_runners(sess, coord=coord)
  80.  
  81. # tf文件读取
  82. print(sess.run([image_batch, target_batch]))
  83.  
  84. # 子线程回收
  85. coord.request_stop()
  86. coord.join(threads)

  八、总结,说起来文件读取只是读取各种数据样本的开始,这里的几种读取方式基本上就是常用的几种形式了。目的是认识常规数据读取的方式。

    但是这里要说明:现在处理数据的方式一般采用tf.data的api来进行数据的处理和调整。所以需要把精力放在tf.data上面。

Python之TensorFlow的数据的读取与存储-2的更多相关文章

  1. python 用codecs实现数据的读取

    import numpy as np import codecs f=codecs.open('testsklearn.txt','r','utf-8').readlines() print(f) d ...

  2. python数据分析之:数据加载,存储与文件格式

    前面介绍了numpy和pandas的数据计算功能.但是这些数据都是我们自己手动输入构造的.如果不能将数据自动导入到python中,那么这些计算也没有什么意义.这一章将介绍数据如何加载以及存储. 首先来 ...

  3. 利用Python进行数据分析_Pandas_数据加载、存储与文件格式

    申明:本系列文章是自己在学习<利用Python进行数据分析>这本书的过程中,为了方便后期自己巩固知识而整理. 1 pandas读取文件的解析函数 read_csv 读取带分隔符的数据,默认 ...

  4. Swift - .plist文件数据的读取和存储

    每次在Xcode中新建一个iOS项目后,都会自己产生一个.plist文件,里面记录项目的一些配置信息.我们也可以自己创建.plist文件来进行数据的存储和读取. .plist文件其实就是一个XML格式 ...

  5. Pandas_数据读取与存储数据(全面但不精炼)

    Pandas 读取和存储数据 目录 读取 csv数据 读取 txt数据 存储 csv 和 txt 文件 读取和存储 json数据 读取和存储 excel数据 一道练习题 参考 Numpy基础(全) P ...

  6. Pandas_数据读取与存储数据(精炼)

    # 一,读取 CSV 文件: # 文字解析函数: # pd.read_csv() 从文件中加载带分隔符的数据,默认分隔符为逗号 # pd.read_table() 从文件中加载带分隔符的数据,默认分隔 ...

  7. 【Python入门只需20分钟】从安装到数据抓取、存储原来这么简单

    基于大众对Python的大肆吹捧和赞赏,作为一名Java从业人员,我本着批判与好奇的心态买了本python方面的书<毫无障碍学Python>.仅仅看了书前面一小部分的我......决定做一 ...

  8. Tensorflow学习-数据读取

    Tensorflow数据读取方式主要包括以下三种 Preloaded data:预加载数据 Feeding: 通过Python代码读取或者产生数据,然后给后端 Reading from file: 通 ...

  9. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

随机推荐

  1. mysql rtrim() 函数

    mysql> select rtrim(" cdcdcd "); +--------------------+ | rtrim(" cdcdcd ") | ...

  2. springboot注解方式使用redis缓存

    引入依赖库 在pom中引入依赖库,如下 <dependency> <groupId>org.springframework.boot</groupId> <a ...

  3. Kubeadm证书过期时间调整

    kubeadm 默认证书为一年,一年过期后,会导致api service不可用,使用过程中会出现:x509: certificate has expired or is not yet valid. ...

  4. 单细胞数据整合方法 | Comprehensive Integration of Single-Cell Data

    操作代码:https://satijalab.org/seurat/ 依赖的算法 CCA CANONICAL CORRELATION ANALYSIS | R DATA ANALYSIS EXAMPL ...

  5. windows正常,linux报错:'PHPExcel_Reader_excel2007' not found

    原因:因为在linux下,大小写敏感 我的文件夹命名是大写,在window小写可以访问到,但是在linux就大小写敏感导致没找到文件没导入成功 导入文件的路径(错误)import('phpexcel. ...

  6. Could not get JDBC Connection; nested exception is java.sql.SQLException: ${jdbc.driver}

    在一个SSM分布式项目中一个服务报错: ### Error querying database. Cause: org.springframework.jdbc.CannotGetJdbcConnec ...

  7. (转载)PyTorch代码规范最佳实践和样式指南

    A PyTorch Tools, best practices & Styleguide 中文版:PyTorch代码规范最佳实践和样式指南 This is not an official st ...

  8. python 中requests的返回数可直接使用json

    对Python的request不是很了解,在使用时才发现,可以把request的请求结果,直接使用.json()['字段名']方法进行一个取值,案例如下 def test_tiantian(self) ...

  9. EasyNVR网页Chrome无插件播放摄像机视频功能二次开发之云台控制接口示例代码

    随着多媒体技术和网络通信技术的迅速发展,视频监控技术在电力系统.电信行业.工业监控.工地.城市交通.水利系统.社区安防等领域得到越来越广泛的应用.摄像头直播视频监控通过网络直接连接,可达到的世界任何角 ...

  10. 手机端rem无限适配

    参考文档: http://blog.csdn.net/xwqqq/article/details/54862279 https://github.com/amfe/lib-flexible/tree/ ...