tf.train.slice_input_producer处理的是来源tensor的数据

转载自:https://blog.csdn.net/dcrmg/article/details/79776876 里面有详细参数解释

官方说明

简单使用

  1. import tensorflow as tf
  2.  
  3. images = ['img1', 'img2', 'img3', 'img4', 'img5']
  4. labels= [1,2,3,4,5]
  5.  
  6. epoch_num=8
  7.  
  8. f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=True)
  9.  
  10. with tf.Session() as sess:
  11. sess.run(tf.global_variables_initializer())
  12. coord = tf.train.Coordinator()
  13. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  14. for i in range(epoch_num):
  15. k = sess.run(f)
  16. print (i,k)
  17.  
  18. coord.request_stop()
  19. coord.join(threads)

运行结果:

用tf.data.Dataset.from_tensor_slices调用,之前的会被抛弃,用法:https://blog.csdn.net/qq_32458499/article/details/78856530

结合批处理

  1. import tensorflow as tf
  2. import numpy as np
  3.  
  4. # 样本个数
  5. sample_num=5
  6. # 设置迭代次数
  7. epoch_num = 2
  8. # 设置一个批次中包含样本个数
  9. batch_size = 3
  10. # 计算每一轮epoch中含有的batch个数
  11. batch_total = int(sample_num/batch_size)+1
  12.  
  13. # 生成4个数据和标签
  14. def generate_data(sample_num=sample_num):
  15. labels = np.asarray(range(0, sample_num))
  16. images = np.random.random([sample_num, 224, 224, 3])
  17. print('image size {},label size :{}'.format(images.shape, labels.shape))
  18.  
  19. return images,labels
  20.  
  21. def get_batch_data(batch_size=batch_size):
  22. images, label = generate_data()
  23. # 数据类型转换为tf.float32
  24. images = tf.cast(images, tf.float32)
  25. label = tf.cast(label, tf.int32)
  26.  
  27. #从tensor列表中按顺序或随机抽取一个tensor
  28. input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
  29.  
  30. image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=64)
  31. return image_batch, label_batch
  32.  
  33. image_batch, label_batch = get_batch_data(batch_size=batch_size)
  34.  
  35. with tf.Session() as sess:
  36. coord = tf.train.Coordinator()
  37. threads = tf.train.start_queue_runners(sess, coord)
  38. try:
  39. for i in range(epoch_num): # 每一轮迭代
  40. print ('************')
  41. for j in range(batch_total): #每一个batch
  42. print ('--------')
  43. # 获取每一个batch中batch_size个样本和标签
  44. image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
  45. # for k in
  46. print(image_batch_v.shape, label_batch_v)
  47. except tf.errors.OutOfRangeError:
  48. print("done")
  49. finally:
  50. coord.request_stop()
  51. coord.join(threads)

运行结果:

tf.train.slice_input_producer()的更多相关文章

  1. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  2. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  3. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

  4. 【转载】 tf.train.slice_input_producer()和tf.train.batch()

    原文地址: https://www.jianshu.com/p/8ba9cfc738c2 ------------------------------------------------------- ...

  5. tensorflow|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners

    #### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_prod ...

  6. tfsenflow队列|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners

      #### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_pr ...

  7. tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...

  8. Tensorflow读取大数据集的方法,tf.train.string_input_producer()和tf.train.slice_input_producer()

    1. https://blog.csdn.net/qq_41427568/article/details/85801579

  9. 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读

    1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...

随机推荐

  1. ubuntu上设置截图快捷键

    ubuntu自带的截图工具感觉能够满足基本的截图功能,可以不必安装另外的截图软件. 一般用到的截图类型有三种:全屏.当前活动窗口.自定义区域,其中自定义区域截图是最灵活也是我们用的最多的方式.在ubu ...

  2. 【JZOJ4743】【NOIP2016提高A组模拟9.2】积木

    题目描述 输入 输出 样例输入 3 8 7 6 3 9 4 1 10 5 输出 18 数据范围 样例解释 解法 容易从n<=15得出可以使用状态压缩动态规划. 设f[i][j][k]表示01状态 ...

  3. Eclipse中提示 找不到类 javax.servlet.http.HttpServletResponse

    问题如题, 解决方案如下: 复制tomcat的安装路径下\lib\servlet-api.jar 到WEB-INF/lib下即可.

  4. Length of Last Word输出最后单词的字母个数

    Given a string s consists of upper/lower-case alphabets and empty space characters ' ', return the l ...

  5. 【JZOJ4841】【NOIP2016提高A组集训第4场11.1】平衡的子集

    题目描述 夏令营有N个人,每个人的力气为M(i).请大家从这N个人中选出若干人,如果这些人可以分成两组且两组力气之和完全相等,则称为一个合法的选法,问有多少种合法的选法? 数据范围 40%的数据满足: ...

  6. java.lang.StackOverflowError 解决办法

    java.lang.StackOverflowError com.sxt.servlet.servlet1.LoginServlet.doGet(LoginServlet.java:15) com.s ...

  7. Chrome的使用技巧总结

    设置一>” 首先设置打开特定的网页 设置--> “下载位置” (Ctrl+H),快速查找自己浏览器访问网页的历史记录 (Ctrl+D),将目前认为比较好的网页保存. ctrl+t  新建标 ...

  8. 《C程序设计语言》笔记(二)

    四:函数与程序结构 1:函数之间的通信可以通过参数.函数返回值以及外部变量进行. 2:如果函数定义中省略了返回值类型,则默认为int类型.如果没有函数原型,则函数将在第一次出现的表达式中被隐式声明,比 ...

  9. 亚洲唯一,阿里云SLB位列Gartner全球网络负载均衡市场前五

    近日,Gartner发布了最新的全球企业级网络设备市场份额报告“Market Share: Enterprise Network Equipment by Market Segment, Worldw ...

  10. spring-jpa通过自定义sql执行修改碰到的问题

    在编写自定义SQL的时候需要注意 @Query 注解只能用来查询,想要进行添加.修改和删除操作需要配合 @Modifying 注解一同使用 @Modifying @Query("update ...