TensorFlow学习笔记3-从MNIST开始
TensorFlow学习笔记3-从MNIST开始学习softmax
本笔记内容为“从MNIST学习softmax regression算法的实现”。
注意:由于我学习机器学习及之前的书写习惯,约定如下:
\(X\)表示训练集的设计矩阵,其大小为m行n列,m表示训练集的大小(size),n表示特征的个数;
\(W\)表示权重矩阵,其大小是n行k列,n为输入特征的个数,k为输出(特征)的个数;
\(\boldsymbol{y}\)表示训练集对应标签,其大小为m行,m表示训练集的大小(size);
\(\boldsymbol{y’}\)表示将测试向量\(x\)输入后得到的测试结果;
总之:
注意区分这里的:\(\boldsymbol{y'}=XW+\boldsymbol{b}\) 表示矩阵形式的预测结果(\(\boldsymbol{y’}\)和\(\boldsymbol{b}\)是向量);
之前机器学习中的是(如《机器学习实战》中SVM一章):$y’=\omega^T x+b $ 表示向量形式的预测结果(\(y'\)和\(b\)是标量);
算法部分:包括预测模型和优化目标
以手写输入MNIST为例:
预测模型
\]
其中softmax函数是归一化函数:
\]
其中\(i , j\)的范围为1~10。softmax函数将\(\boldsymbol{z}\)归一化之后变为\(\boldsymbol{y’}\)(预测值)。如下图。
- 训练集:共55000条数据,每条数据中有784个特征(将28*28个像素点进行展开,忽略了像素间的结构关系),矩阵中m=55000,n=784;
- 参数\(W\)中的元素\(W_{i,j}\)的含义是:第i个像素点在数字j中占的权重,意思是如果很多数字j的实例中都有i,说明像素点i很大可能代表数字j,那么其权重会很大。
- 参数\(b\)中的元素\(b_{i,j}\)的含义是:第i个像素点在数字j的偏置量,意思是如果大部分数字都是0,则0的特征对应的bias值会很大。
优化目标:交叉熵的最小化
交叉熵:
\]
其中,
每个batch中的所有预测项的交叉熵的平均值为评价指标。
实现部分:
用随机梯度下降优化器对评价指标进行优化。
每次随机选取训练集中的100个子集作为batch(桶)进行训练,共训练1000次。
预测模型的评价
统计准确率。
附代码:
import tensorflow as tf
# 1 Collect data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True);
print(mnist.train.images.shape, mnist.train.labels.shape);
print(mnist.test.images.shape,mnist.test.labels.shape);
print(mnist.validation.images.shape,mnist.validation.labels.shape);
# 2 Create Model
X = tf.placeholder(tf.float32,[None,784]);
y = tf.placeholder(tf.float32,[None,10]);
W = tf.Variable(tf.random_uniform([784,10],-1,1));
b = tf.Variable(tf.zeros([10]));
z = tf.matmul(X,W)+b;
y_ = tf.nn.softmax(z);
# 3 loss function
loss = -tf.reduce_mean(tf.reduce_sum(y*tf.log(y_),axis=1));
optimizer = tf.train.GradientDescentOptimizer(0.5);
train = optimizer.minimize(loss);
# 4 initialzer
init = tf.initialize_all_variables();
sess = tf.InteractiveSession();
sess.run(init);
# 5 Train
for step in range(1000):
x_batch,y_batch = mnist.train.next_batch(100);
sess.run(train,feed_dict={X:x_batch,y:y_batch});
if step%10 ==0:
print(step/10,"%",sess.run(loss,feed_dict={X:x_batch,y:y_batch}));
# 6 Output
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1));
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32));
print(accuracy.eval({X:mnist.test.images,y:mnist.test.labels}));
sess.close();
更进一步
- 使用
InteractiveSession
将这个session注册为默认的session,之后的运算都默认跑在这个session里,不同session之间的运算与数据相互独立。
比较
batch_xs, batch_ys = mnist.train.next_batch(100) # 使用minibatch,一个batch大小为100
train_step.run({x: batch_xs, y: batch_ys})
与
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
的异同。
本质没有区别:也就是说只要是字典dict形式的写法,就是输入;否则就是输出。
TensorFlow学习笔记3-从MNIST开始的更多相关文章
- tensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
刚开始学习tf时,我们从简单的地方开始.卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始. 神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输 ...
- Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解
1 #coding:utf-8 # 日期 2017年9月4日 环境 Python 3.5 TensorFlow 1.3 win10开发环境. import tensorflow as tf from ...
- tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)
mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试 刚开始学习tf时,我们从 ...
随机推荐
- framebuffer设备驱动分析
一.设备驱动相关文件 1.1. 驱动框架相关文件 1.1.1. drivers/video/fbmem.c a. 创建graphics类.注册FB的字符设备驱动 fbmem_init(void) { ...
- CentOS 5.5编译安装lnmp
如果是安装Centos6.5记得Perl是必选的,否则无法安装VMWare Tools!!!!切记 如果出现make错误需要安装其他软件,装好后 make clean make install ...
- 移动端抓包工具——Fiddler(一)
web端抓包一般利用浏览器自带的检查功能(F12),然后看Network项,根据请求响应判断出问题 移动端怎么抓包呢,这里介绍一款常用的抓包工具——Fiddler 前提: 1.必须确保安装fiddl ...
- 第一个chrome extension
如今,chrome浏览器的使用如越来越流行,chrome extension往往能提供更多很丰富的功能.以前一直想了解这方面的东西,可是又担心很复杂.前段时间,在斗鱼看一个直播,想刷弹幕,但是每次自己 ...
- VMware新加网卡NAT连接(内网)出现本机与虚拟机ping不通的问题
今新加网卡NAT连接,配置好之后始终出现eth1:link is not ready. 虚拟机与本机不能建立连接. 解决方案:windows里面打开服务开启VMware NAT Service,并关闭 ...
- 在虚拟机Linux中安装VMTools遇到的问题-小结
总结: 遇到的问题:No support for locale: zh_CN.utf8 可能的解决方法:1.sudo dpkg-reconfigure locale (重新配置?) 2.上一步失败,提 ...
- 吉首大学2019年程序设计竞赛(重现赛)I 滑稽树上滑稽果 (莫队+逆元打表)
链接:https://ac.nowcoder.com/acm/contest/992/I来源:牛客网 时间限制:C/C++ 1秒,其他语言2秒空间限制:C/C++ 32768K,其他语言65536K ...
- Solr从数据库导入数据(DIH)
一. 数据导入(DataImportHandler-DIH) DIH 是solr 提供的一种针对数据库.xml/HTTP.富文本对象导入到solr 索引库的工具包.这里只针对数据库做介绍. A.准备以 ...
- tpcc-mysql测试mysql5.6 (EXT4文件系统)
操作系统版本:CentOS release 6.5 (Final) 2.6.32-431.el6.x86_64 #1 内存:32G CPU:Intel(R) Xeon(R) CPU E5-2450 ...
- Typescript + TSLint + webpack 搭建 Typescript 的开发环境
(1)初始化项目 新建一个文件夹“client-side”,作为项目根目录,进入这个文件夹: 我们先使用 npm 初始化这个项目: 这时我们看到了在根目录下已经创建了一个 package.json 文 ...