不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;

加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

04Dropout的更多相关文章

随机推荐

  1. window 连接服务器工具

    Xshell xftp 下载网址 以上两个软件均免费, 只需要邮件激活即可. 其中 xshell 主要用来连接服务器,方便使用命令行.xftp 方便传输文件.

  2. 实用工具/API

    实用工具/API PNG图片无损压缩 在线给图片加水印 随机密码生成 随机头像生成 微博一键清理工具 CSS压缩 在线工具 免费虚拟主机 技术摘要 https://github.com/biezhi/ ...

  3. E. Compress Words

    E. Compress Words KMP #include<bits/stdc++.h> using namespace std; ]; int len; void getNext(ch ...

  4. Oracle-存储过程实现更改用户密码

    --调用存储过程实现更改DB用户密码 CREATE OR REPLACE PROCEDURE MODUSERPW(USER_NAME VARCHAR2,USER_PW VARCHAR2)ISSQLTX ...

  5. [CSP-S模拟测试]:老司机的狂欢(LIS+LCA)

    题目背景 光阴荏苒.不过,两个人还在,两支车队还在,熟悉的道路.熟悉的风景,也都还在.只是,这一次,没有了你死我活的博弈,似乎和谐了许多.然而在机房是不允许游戏的,所以班长$XZY$对游戏界面进行了降 ...

  6. Qt之zip压缩/解压缩(QuaZIP)

    摘要: 简述 QuaZIP是使用Qt/C++对ZLIB进行简单封装的用于压缩及解压缩ZIP的开源库.适用于多种平台,利用它可以很方便的将单个或多个文件打包为zip文件,且打包后的zip文件可以通过其它 ...

  7. SQL数据库字段添加说明文字

    1.查看指定表中的所有带说明文字的字段内容 SELECT *,OBJECT_NAME(major_id) AS obj_name FROM sys.extended_properties WHERE ...

  8. Understanding ECMAScript 6 阅读问题小记

    拖了一年说要看这本书,一直都没坚持下来,开个 bo 记录下觉得疑惑的问题,也算鞭策一下自己. 第一章 块级绑定 1. 第一章“块级绑定”下,说 const 变量如果绑定的是对象 Object,那么修改 ...

  9. 操作Redis--hash/key-value

    redis也是一个数据库,它的存储以key-value的方式存放,比如: a.关系型数据库 比如: mysql.oracle.sql server.db2.sqlite数据库,为关系型数据库 数据通过 ...

  10. 【ABAP系列】SAP 读取生产订单 记入文档的货物移动明细

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[ABAP系列]SAP 读取生产订单 记入文档的 ...