04Dropout
不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;
加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

04Dropout的更多相关文章
随机推荐
- nginx 缓存设置
浏览器缓存原理 浏览器缓存 HTTP协议定义的缓存机制(如:Expires:Cache-control等) 2.浏览器无缓存 3.客户端有缓存 校验过期机制 校验是否过期 ...
- nginx 静态资源WEB服务
1.静态资源类型 非服务器动态运行生成的文件 类型种类 浏览器端渲染 HTML.CSS.JS 图片 JPEG.GIF.PNG 视频 FLV.MPEG ...
- DC.p4: programming the forwarding plane of a data-center switch
Name of article:Dc. p4: Programming the forwarding plane of a data-center switch Origin of the artic ...
- HDU 6438 Buy and Resell
高卖低买,可以交易多次 维护一个优先队列,贪心 相当于每天卖出 用当前元素减优先队列最小得到收益 用0/卖出,1/买入标志是否真实进行了交易,记录次数 #include<bits/stdc++. ...
- php正则匹配汉字提取其它信息剔除和验证邮箱
正则匹配汉字提取其它信息剔除demo <?php //提取字符串中的汉字其余信息剔除 $str='te,st 测 .试,.,.?!::·…~&@#,.?!:;.……-&@#“” ...
- db2用户权限赋值
<!------------创建db2用户和组-------------------------------------------> [root@localhost ~]# userad ...
- Vultr CentOS下后台跑node
在Mac或者Windows下简直易如反掌.几行命令搞定的事情,但因为使用的是远程SSH连接纯命令行处理,所以需要记录下来怎么弄. 比如, 1. 怎么在什么都没有的CentOS里下载Node安装包? 2 ...
- java通过HtmlUnit工具和J4L实现模拟带验证码登录
1.HtmlUnit 1.1介绍 HtmlUnit是一个用java编写的无界面浏览器,建模html文档,通过API调用页面,填充表单,点击链接等等.如同正常浏览器一样操作.典型应用于测试以及从网页抓取 ...
- 安装element-ui
element地址:https://element.eleme.cn/2.0/#/zh-CN/component/quickstart 1.在新建终端 [安装element-ui组件依赖]cnpm i ...
- python实现建立soap通信(调用及测试webservice接口)
实现代码如下: #调用及测试webservice接口 import requests class SoapConnect: def get_soap(self,url,data): r = reque ...