『TensorFlow』读书笔记_多层感知机
多层感知机
输入->线性变换->Relu激活->线性变换->Softmax分类
多层感知机将mnist的结果提升到了98%左右的水平
知识点
过拟合:采用dropout解决,本质是bagging方法,相当于集成学习,注意dropout训练时设置为0~1的小数,测试时设置为1,不需要关闭节点
学习率难以设定:Adagrad等自适应学习率方法
深层网络梯度弥散:Relu激活取代sigmoid激活,不过输出层仍然使用sigmoid激活
对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称
对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重
代码如下
# Author : Hellcat
# Time : 2017/12/7 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('../../../Mnist_data',one_hot=True)
sess = tf.InteractiveSession() in_units = 784
h1_units = 300 # 对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称
# 对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重
W1 = tf.Variable(tf.truncated_normal([in_units, h1_units],stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units], dtype=tf.float32))
W2 = tf.Variable(tf.zeros([h1_units, 10], dtype=tf.float32))
b2 = tf.Variable(tf.zeros([10], dtype=tf.float32)) x = tf.placeholder(tf.float32, [None, in_units])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
y = tf.nn.softmax(tf.add(tf.matmul(hidden1_drop, W2), b2)) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))
# train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y,axis=1), tf.argmax(y_,axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.global_variables_initializer().run()
for i in range(3000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x:batch_xs, y_:batch_ys, keep_prob:0.5})
if i % 100 == 0:
print('当前迭代次数{0},当前准确率{1:.3f}'.
format(i,accuracy.eval({x:batch_xs, y_:batch_ys, keep_prob:1.0})))
print(accuracy.eval({x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))
输出如下,
当前迭代次数0,当前准确率0.350
当前迭代次数100,当前准确率0.950
当前迭代次数200,当前准确率0.960
当前迭代次数300,当前准确率0.940
当前迭代次数400,当前准确率0.940
当前迭代次数500,当前准确率0.980
当前迭代次数600,当前准确率0.990
当前迭代次数700,当前准确率0.990
当前迭代次数800,当前准确率1.000
当前迭代次数900,当前准确率0.970
当前迭代次数1000,当前准确率0.980
当前迭代次数1100,当前准确率0.960
当前迭代次数1200,当前准确率1.000
当前迭代次数1300,当前准确率0.970
当前迭代次数1400,当前准确率0.990
当前迭代次数1500,当前准确率1.000
当前迭代次数1600,当前准确率1.000
当前迭代次数1700,当前准确率1.000
当前迭代次数1800,当前准确率0.980
当前迭代次数1900,当前准确率0.980
当前迭代次数2000,当前准确率1.000
当前迭代次数2100,当前准确率1.000
当前迭代次数2200,当前准确率1.000
当前迭代次数2300,当前准确率0.990
当前迭代次数2400,当前准确率1.000
当前迭代次数2500,当前准确率1.000
当前迭代次数2600,当前准确率0.990
当前迭代次数2700,当前准确率0.980
当前迭代次数2800,当前准确率0.990
当前迭代次数2900,当前准确率0.980
0.9778
有意思的是,使用tf.train.AdagradOptimizer()优化器时偶尔会出错,使用梯度下降优化器之后再修改回来就没问题了,可能是我的解释器出问题了。
『TensorFlow』读书笔记_多层感知机的更多相关文章
- 『TensorFlow』读书笔记_降噪自编码器
『TensorFlow』降噪自编码器设计 之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...
- 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上
完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...
- 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_下
数据读取部分实现 文中采用了tensorflow的从文件直接读取数据的方式,逻辑流程如下, 实现如下, # Author : Hellcat # Time : 2017/12/9 import os ...
- 『TensorFlow』读书笔记_简单卷积神经网络
如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质. 越是上层的,就越接近实际内容(能说出来是个什么东西的那些 ...
- 『TensorFlow』读书笔记_VGGNet
VGGNet网络介绍 VGG系列结构图, 『cs231n』卷积神经网络工程实践技巧_下 1,全部使用3*3的卷积核和2*2的池化核,通过不断加深网络结构来提升性能. 所有卷积层都是同样大小的filte ...
- 『TensorFlow』读书笔记_ResNet_V2
『PyTorch × TensorFlow』第十七弹_ResNet快速实现 要点 神经网络逐层加深有Degradiation问题,准确率先上升到饱和,再加深会下降,这不是过拟合,是测试集和训练集同时下 ...
- 『TensorFlow』读书笔记_AlexNet
网络结构 创新点 Relu激活函数:效果好于sigmoid,且解决了梯度弥散问题 Dropout层:Alexnet验证了dropout层的效果 重叠的最大池化:此前以平均池化为主,最大池化避免了平均池 ...
- 『TensorFlow』读书笔记_Inception_V3_下
极为庞大的网络结构,不过下一节的ResNet也不小 线性的组成,结构大体如下: 常规卷积部分->Inception模块组1->Inception模块组2->Inception模块组3 ...
- 『TensorFlow』读书笔记_TFRecord学习
一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...
随机推荐
- 深度学习基础(一)LeNet_Gradient-Based Learning Applied to Document Recognition
作者:Yann LeCun,Leon Botton, Yoshua Bengio,and Patrick Haffner 这篇论文内容较多,这里只对部分内容进行记录: 以下是对论文原文的翻译: 在传统 ...
- P5280 [ZJOI2019]线段树
题目链接:洛谷 题目描述:[比较复杂,建议看原题] 这道题太神仙了,线段树上做树形dp. 根据树形dp的套路,都是按照转移的不同情况给节点分类.这里每次modify的时候对于节点的影响也不同,所以我们 ...
- python框架之Django(8)-CBV中添加装饰器
现有如下检查登录装饰器: from functools import wraps def check_login(func): @wraps(func) def inner(request, *arg ...
- 变量存储缓存机制 Number (int bool float complex)
# ###变量存储的缓存机制(为了节省空间) #Number (int bool float complex) # (1) int -5~正无穷范围内 var1 = 18 var2 = 18 var1 ...
- ubuntu12.04下编译Linux tina 2.1/android经验
用的是osboxes下的vdi. 编译Linux 1. 不能在root用户下操作 2. 执行 make kernel_menuconfig 报错,需要 apt-get install zlib1g z ...
- Unicode编码与中文互转
/** * unicode编码转换为汉字 * @param unicodeStr 待转化的编码 * @return 返回转化后的汉子 */ public static String UnicodeTo ...
- Mysql授权root用户远程登录
默认情况下Mysql的root用户不支持远程登录,使用以下命令授权 [Charles@localhost ~]$ mysql -uroot -p123 MariaDB [(none)]> u ...
- Ngnix 配置文件
配置文件路径/usr/local/nginx/conf/nginx.conf user www www; #nginx 服务的伪用户和用户组 worker_processes auto; #启动进程, ...
- Practical Lessons from Predicting Clicks on Ads at Facebook
ABSTRACT 这篇paper中作者结合GBDT和LR,取得了很好的效果,比单个模型的效果高出3%.随后作者研究了对整体预测系统产生影响的几个因素,发现Feature(能挖掘出用户和广告的历史信息) ...
- 第一篇——Struts2的工作原理及HelloWorld简单实现
Struts2工作原理: 一个请求在Struts框架中的处理步骤: 1.客户端初始化一个指向Servlet容器(例如Tomcat)的请求: 2.这个请求经过一系列的过滤器(Filter): 3.接着F ...