https://blog.csdn.net/lujiandong1/article/details/53991373

方式一:不显示设置读取N个epoch的数据,而是使用循环,每次从训练的文件中随机读取一个batch_size的数据,直至最后读取的数据量达到N个epoch。说明,这个方式来实现epoch的输入是不合理。不是说每个样本都会被读取到的。

对于这个的解释,从数学上解释,比如说有放回的抽样,每次抽取一个样本,抽取N次,总样本数为N个。那么,这样抽取过一轮之后,该样本也是会有1/e的概率没有被抽取到。所以,如果使用这种方式去训练的话,理论上是没有用到全部的数据集去训练的,很可能会造成过拟合的现象。

我做了个小实验验证:

  1.  
    import tensorflow as tf
  2.  
    import numpy as np
  3.  
    import datetime,sys
  4.  
    from tensorflow.contrib import learn
  5.  
    from model import CCPM
  6.  
     
  7.  
    training_epochs = 5
  8.  
    train_num = 4
  9.  
    # 运行Graph
  10.  
    with tf.Session() as sess:
  11.  
     
  12.  
    #定义模型
  13.  
    BATCH_SIZE = 2
  14.  
    # 构建训练数据输入的队列
  15.  
    # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
  16.  
    filenames = ['a.csv']
  17.  
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
  18.  
    # 定义Reader
  19.  
    reader = tf.TextLineReader()
  20.  
    key, value = reader.read(filename_queue)
  21.  
    # 定义Decoder
  22.  
    # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
  23.  
    #更改了特征,使用不同的解析参数
  24.  
    record_defaults = [[1]]*5
  25.  
    col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
  26.  
    features = tf.pack([col1,col2,col3,col4])
  27.  
    label = tf.pack([col5])
  28.  
     
  29.  
    example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
  30.  
     
  31.  
    sess.run(tf.initialize_all_variables())
  32.  
    coord = tf.train.Coordinator()#创建一个协调器,管理线程
  33.  
    threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
  34.  
    #开始一个epoch的训练
  35.  
    for epoch in range(training_epochs):
  36.  
    total_batch = int(train_num/BATCH_SIZE)
  37.  
    #开始一个epoch的训练
  38.  
    for i in range(total_batch):
  39.  
    X,Y = sess.run([example_batch, label_batch])
  40.  
    print X,':',Y
  41.  
    coord.request_stop()
  42.  
    coord.join(threads)

toy data a.csv:

说明:输出如下,可以看出并不是每个样本都被遍历5次,其实这样的话,对于DL的训练会产生很大的影响,并不是每个样本都被使用同样的次数。

方式二:显示设置epoch的数目

  1.  
    #-*- coding:utf-8 -*-
  2.  
    import tensorflow as tf
  3.  
    import numpy as np
  4.  
    import datetime,sys
  5.  
    from tensorflow.contrib import learn
  6.  
    from model import CCPM
  7.  
     
  8.  
    training_epochs = 5
  9.  
    train_num = 4
  10.  
    # 运行Graph
  11.  
    with tf.Session() as sess:
  12.  
     
  13.  
    #定义模型
  14.  
    BATCH_SIZE = 2
  15.  
    # 构建训练数据输入的队列
  16.  
    # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
  17.  
    filenames = ['a.csv']
  18.  
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True,num_epochs=training_epochs)
  19.  
    # 定义Reader
  20.  
    reader = tf.TextLineReader()
  21.  
    key, value = reader.read(filename_queue)
  22.  
    # 定义Decoder
  23.  
    # 编码后的数据字段有24,其中22维是特征字段,2维是lable字段,label是二分类经过one-hot编码后的字段
  24.  
    #更改了特征,使用不同的解析参数
  25.  
    record_defaults = [[1]]*5
  26.  
    col1,col2,col3,col4,col5 = tf.decode_csv(value,record_defaults=record_defaults)
  27.  
    features = tf.pack([col1,col2,col3,col4])
  28.  
    label = tf.pack([col5])
  29.  
     
  30.  
    example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=BATCH_SIZE, capacity=20000, min_after_dequeue=4000, num_threads=2)
  31.  
    sess.run(tf.initialize_local_variables())
  32.  
    sess.run(tf.initialize_all_variables())
  33.  
    coord = tf.train.Coordinator()#创建一个协调器,管理线程
  34.  
    threads = tf.train.start_queue_runners(coord=coord)#启动QueueRunner, 此时文件名队列已经进队。
  35.  
    try:
  36.  
    #开始一个epoch的训练
  37.  
    while not coord.should_stop():
  38.  
    total_batch = int(train_num/BATCH_SIZE)
  39.  
    #开始一个epoch的训练
  40.  
    for i in range(total_batch):
  41.  
    X,Y = sess.run([example_batch, label_batch])
  42.  
    print X,':',Y
  43.  
    except tf.errors.OutOfRangeError:
  44.  
    print('Done training')
  45.  
    finally:
  46.  
    coord.request_stop()
  47.  
    coord.join(threads)

说明:输出如下,可以看出每个样本都被访问5次,这才是合理的设置epoch数据的方式。


http://stats.stackexchange.com/questions/242004/why-do-neural-network-researchers-care-about-epochs

说明:这个博客也在探讨,为什么深度网络的训练中,要使用epoch,即要把训练样本全部过一遍.而不是随机有放回的从里面抽样batch_size个样本.在博客中,别人的实验结果是如果采用有放回抽样的这种方式来进行SGD的训练.其实网络见不到全部的数据集,推导过程如上所示.所以,网络的收敛速度比较慢.

tesnorflow实现N个epoch训练数据读取的办法的更多相关文章

  1. tensorflow读取训练数据方法

    1. 预加载数据 Preloaded data # coding: utf-8 import tensorflow as tf # 设计Graph x1 = tf.constant([2, 3, 4] ...

  2. TensorFlow Distribution(分布式中的数据读取和训练)

    本文目的 在介绍estimator分布式的时候,官方文档由于版本更新导致与接口不一致.具体是:在estimator分布式当中,使用dataset作为数据输入,在1.12版本中,数据训练只是datase ...

  3. TensorFlow实践笔记(一):数据读取

    本文整理了TensorFlow中的数据读取方法,在TensorFlow中主要有三种方法读取数据: Feeding:由Python提供数据. Preloaded data:预加载数据. Reading ...

  4. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

  5. tensorflow之数据读取探究(1)

    Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...

  6. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  7. tensoflow数据读取

    数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFl ...

  8. TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

  9. 详解Tensorflow数据读取有三种方式(next_batch)

    转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...

随机推荐

  1. sqlserver 2012 查询时提示“目录名称无效”

    重装系统或者用360等软件清理了相应的临时文件导致解决:在运行中输入 %temp% 回车,会跳出找不到路径的提示,然后到提示的目录建没有找到的目录文件夹即可.

  2. 使用position:relative制作下边框下的小三角

    在制作tab选项卡的时候,有时会有下边框,且下边框下另一个头向下的小三角,这全然能够用css来实现,而不必使用背景图片. 由于使用背景图片时会有一个问题,选项卡内容字数不同.导致使用背景图片时无法控制 ...

  3. C#实现路由器断开连接,更改公网ip

    publicstaticvoid Disconnect() { string url ="断 线";    string uri ="http://192.168.1.1 ...

  4. [Bug]Unable to start process dotnet.exe

    This morning I did a sync of a repo using of Visual Studio and then tried to run a web application I ...

  5. How to update WPF browser application manifest and xbap file with ‘mage.exe’

    老外参考文章1 老外参考文章2 I created a WPF browser application MyApp then published it by ClickOnce in VS2008. ...

  6. DevExpress VCL for Delphi 各版本收集下载

    更多VCL组件请到:http://maxwoods.400gb.com/u/758954/1974711 DevExpress VCL 5.7:http://www.ctdisk.com/file/7 ...

  7. 【mybatis】【mysql】mybatis查询mysql,group by分组查询报错:Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated column

    mybatis查询mysql,group by分组查询报错:Expression #1 of SELECT list is not in GROUP BY clause and contains no ...

  8. android studio# jdk8# class file for java.lang.invoke.MethodType not found

    https://github.com/evant/gradle-retrolambda/issues/23 class file for java.lang.invoke.MethodType not ...

  9. Glibc 和 uClibc

    转自:https://blog.csdn.net/clirus/article/details/50145959?locationNum=4 最近在搞mips openwrt框架的东西,mipc的GC ...

  10. 采用redis 主从架构的原因

    如果系统的QPS超过10W+,甚至是百万以上的访问,则光是Redis是不够的,但是Redis是整个大型缓存架构中,支撑高并发的架构非常重要的环节. 首先,你的缓存中间件.缓存系统,必须能够支撑起10w ...