TensorFlow Tutorial ------ Overfitting

 

[Figure: overfitting in regression]

[Figure: overfitting in classification]

There are three common ways to prevent overfitting:

1. Increase the size of the training set

2. Add a regularization term to the loss (see the sketch after this list)

3. Dropout: during training, a random subset of the hidden neurons is dropped in each iteration, so only part of the network participates in each update
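As an illustration of method 2, here is a minimal sketch of adding an L2 penalty to the loss, written in the same TensorFlow 1.x style as the code below; the single linear layer, the variable names, and the regularization strength l2_lambda are assumptions made only for illustration and are not part of the original tutorial.

# coding: utf-8
import tensorflow as tf

# A single linear layer used only to illustrate the L2 penalty (assumed setup).
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(x, W) + b

# Data term (cross-entropy) plus an L2 penalty that discourages large weights.
data_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
l2_lambda = 5e-4  # regularization strength, an assumed value
loss = data_loss + l2_lambda * tf.nn.l2_loss(W)
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

The penalty term l2_lambda * tf.nn.l2_loss(W) pushes the weights toward zero, which limits how precisely the network can memorize the training samples.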

Finally, the earlier plain-neural-network MNIST classification code is improved: the weights are initialized from a truncated normal distribution, the biases are initialized to a small constant, dropout is applied to prevent overfitting, and three hidden layers are added. The final test accuracy reaches over 97%. The code is as follows:

# coding: utf-8

# WeChat official account: 深度学习与神经网络
# GitHub: https://github.com/Qinbf
# Youku channel: http://i.youku.com/sdxxqbf

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each mini-batch
batch_size = 100
# Number of batches per epoch
n_batch = mnist.train.num_examples // batch_size

# Define the placeholders
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

# Build a simple feed-forward network:
# weights drawn from a truncated normal, biases initialized to a small constant,
# dropout applied after each hidden layer
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob)

# Output layer: keep the pre-softmax logits for the loss, softmax for prediction
W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L3_drop, W4) + b4
prediction = tf.nn.softmax(logits)

# Quadratic cost (kept here for reference, not used)
# loss = tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy loss; softmax_cross_entropy_with_logits expects the raw logits
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Train with gradient descent
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# Variable initializer
init = tf.global_variables_initializer()

# correct_prediction is a list of booleans; argmax returns the index of the largest value along the given axis
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy: cast the booleans to floats and average them
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})

        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

The results (trained with keep_prob fed as 0.7) are as follows:

Iter 0,Testing Accuracy 0.913,Training Accuracy 0.909146
Iter 1,Testing Accuracy 0.9318,Training Accuracy 0.927218
Iter 2,Testing Accuracy 0.9397,Training Accuracy 0.9362
Iter 3,Testing Accuracy 0.943,Training Accuracy 0.940637
Iter 4,Testing Accuracy 0.9449,Training Accuracy 0.945746
Iter 5,Testing Accuracy 0.9489,Training Accuracy 0.949491
Iter 6,Testing Accuracy 0.9505,Training Accuracy 0.9522
Iter 7,Testing Accuracy 0.9542,Training Accuracy 0.956
Iter 8,Testing Accuracy 0.9543,Training Accuracy 0.957782
Iter 9,Testing Accuracy 0.954,Training Accuracy 0.959
Iter 10,Testing Accuracy 0.9558,Training Accuracy 0.959582
Iter 11,Testing Accuracy 0.9594,Training Accuracy 0.963146
Iter 12,Testing Accuracy 0.959,Training Accuracy 0.963746
Iter 13,Testing Accuracy 0.961,Training Accuracy 0.964764
Iter 14,Testing Accuracy 0.9605,Training Accuracy 0.9658
Iter 15,Testing Accuracy 0.9635,Training Accuracy 0.967528
Iter 16,Testing Accuracy 0.9639,Training Accuracy 0.968582
Iter 17,Testing Accuracy 0.9644,Training Accuracy 0.969309
Iter 18,Testing Accuracy 0.9651,Training Accuracy 0.969564
Iter 19,Testing Accuracy 0.9664,Training Accuracy 0.971073
Iter 20,Testing Accuracy 0.9654,Training Accuracy 0.971746
Iter 21,Testing Accuracy 0.9664,Training Accuracy 0.971764
Iter 22,Testing Accuracy 0.9682,Training Accuracy 0.973128
Iter 23,Testing Accuracy 0.9679,Training Accuracy 0.973346
Iter 24,Testing Accuracy 0.9681,Training Accuracy 0.975164
Iter 25,Testing Accuracy 0.969,Training Accuracy 0.9754
Iter 26,Testing Accuracy 0.9706,Training Accuracy 0.975764
Iter 27,Testing Accuracy 0.9694,Training Accuracy 0.975837
Iter 28,Testing Accuracy 0.9703,Training Accuracy 0.977109
Iter 29,Testing Accuracy 0.97,Training Accuracy 0.976946
Iter 30,Testing Accuracy 0.9715,Training Accuracy 0.977491
The gap between Testing Accuracy and Training Accuracy is 0.005991.

When the dropout keep probability keep_prob is set to 1.0 (i.e. dropout disabled), the results are:
Iter 0,Testing Accuracy 0.9471,Training Accuracy 0.955037
Iter 1,Testing Accuracy 0.9597,Training Accuracy 0.9738
Iter 2,Testing Accuracy 0.9616,Training Accuracy 0.980928
Iter 3,Testing Accuracy 0.9661,Training Accuracy 0.985091
Iter 4,Testing Accuracy 0.9674,Training Accuracy 0.987709
Iter 5,Testing Accuracy 0.9692,Training Accuracy 0.989255
Iter 6,Testing Accuracy 0.9692,Training Accuracy 0.990146
Iter 7,Testing Accuracy 0.9708,Training Accuracy 0.991182
Iter 8,Testing Accuracy 0.9711,Training Accuracy 0.991982
Iter 9,Testing Accuracy 0.9712,Training Accuracy 0.9924
Iter 10,Testing Accuracy 0.971,Training Accuracy 0.992691
Iter 11,Testing Accuracy 0.9706,Training Accuracy 0.993055
Iter 12,Testing Accuracy 0.971,Training Accuracy 0.993309
Iter 13,Testing Accuracy 0.9717,Training Accuracy 0.993528
Iter 14,Testing Accuracy 0.9719,Training Accuracy 0.993764
Iter 15,Testing Accuracy 0.9715,Training Accuracy 0.993927
Iter 16,Testing Accuracy 0.9715,Training Accuracy 0.994091
Iter 17,Testing Accuracy 0.9714,Training Accuracy 0.994291
Iter 18,Testing Accuracy 0.9719,Training Accuracy 0.9944
Iter 19,Testing Accuracy 0.9719,Training Accuracy 0.994564
Iter 20,Testing Accuracy 0.9722,Training Accuracy 0.994673
Iter 21,Testing Accuracy 0.9725,Training Accuracy 0.994855
Iter 22,Testing Accuracy 0.9731,Training Accuracy 0.994891
Iter 23,Testing Accuracy 0.9721,Training Accuracy 0.994928
Iter 24,Testing Accuracy 0.9722,Training Accuracy 0.995018
Iter 25,Testing Accuracy 0.9725,Training Accuracy 0.995109
Iter 26,Testing Accuracy 0.9729,Training Accuracy 0.9952
Iter 27,Testing Accuracy 0.9726,Training Accuracy 0.995255
Iter 28,Testing Accuracy 0.9725,Training Accuracy 0.995327
Iter 29,Testing Accuracy 0.9725,Training Accuracy 0.995364
Iter 30,Testing Accuracy 0.9722,Training Accuracy 0.995437
The gap between Testing Accuracy and Training Accuracy is now 0.023237. This experiment uses only 60,000 training samples; with datasets of several million samples the gap becomes even larger: the trained model fits the training set almost perfectly, yet performs noticeably worse on the test set. This is the typical symptom of overfitting.

Therefore, dropout is generally added to any moderately complex network to help prevent overfitting.
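As a side note, tf.nn.dropout in TensorFlow 1.x scales the surviving activations by 1/keep_prob during training, which is why the evaluation above can simply feed keep_prob = 1.0 without any extra rescaling at test time. A minimal sketch with assumed toy values makes this visible:

# coding: utf-8
import tensorflow as tf

keep_prob = tf.placeholder(tf.float32)
activations = tf.constant([[1.0, 2.0, 3.0, 4.0]])  # toy activations, assumed values
dropped = tf.nn.dropout(activations, keep_prob)    # kept entries are scaled by 1/keep_prob

with tf.Session() as sess:
    # With keep_prob = 0.5, roughly half of the entries are zeroed and the survivors are doubled.
    print(sess.run(dropped, feed_dict={keep_prob: 0.5}))
    # With keep_prob = 1.0 the input passes through unchanged.
    print(sess.run(dropped, feed_dict={keep_prob: 1.0}))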
