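# Requires TensorFlow 1.x (tf.placeholder, tf.Session, tensorflow.examples.tutorials)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each batch
batch_size = 64
# Total number of batches per epoch
n_batch = mnist.train.num_examples // batch_size

# Define three placeholders: inputs, one-hot labels, and the dropout keep probability
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

# Network architecture: 784-1000-500-10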
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L2_drop, W3) + b3)

# Cross-entropy loss
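# Note: tf.losses.softmax_cross_entropy expects unscaled logits and applies
# softmax internally. `prediction` is already a softmax output, so softmax is
# effectively applied twice here. The run logged below used this form.
loss = tf.losses.softmax_cross_entropy(y, prediction)
# Use gradient descent
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()

# Store the results in a list of booleans
# (argmax returns the position of the largest value along the given axis)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Compute the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Train with dropout: each unit is kept with probability 0.5
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.5})

        # Evaluate with dropout disabled (keep_prob = 1.0)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))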
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
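
The script feeds keep_prob = 0.5 during training and keep_prob = 1.0 during evaluation. As a rough illustration of what that parameter means, here is a minimal plain-Python sketch of "inverted" dropout, the behavior documented for tf.nn.dropout in TensorFlow 1.x: each element is kept with probability keep_prob and scaled by 1/keep_prob, otherwise it is set to 0. This is not TensorFlow's implementation; the function name `dropout`, the `rng` parameter, and the sample activations are assumptions made up for the example.

```python
import random


def dropout(values, keep_prob, rng=random):
    """Inverted dropout on a flat list of activations (illustrative sketch).

    Each element is kept with probability keep_prob and scaled by
    1 / keep_prob; otherwise it becomes 0.0. With keep_prob == 1.0 every
    element is kept and the scale factor is 1, so the input passes through.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    scale = 1.0 / keep_prob
    # rng.random() is in [0, 1), so keep_prob == 1.0 always keeps the element.
    return [v * scale if rng.random() < keep_prob else 0.0 for v in values]


rng = random.Random(0)                  # fixed seed, hypothetical choice
activations = [0.5, -1.0, 2.0, 0.25]    # made-up layer outputs

train_out = dropout(activations, keep_prob=0.5, rng=rng)  # training mode
eval_out = dropout(activations, keep_prob=1.0, rng=rng)   # evaluation mode

print(train_out)  # each entry is either 0.0 or twice the original value
print(eval_out)   # identical to activations
```

Scaling the surviving units by 1/keep_prob keeps the expected value of each activation unchanged, which is why the same weights can be used at evaluation time simply by setting keep_prob to 1.0.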
