Tensorflow学习笔记(一):MNIST机器学习入门
学习深度学习,首先从深度学习的入门MNIST入手。通过这个例子,了解Tensorflow的工作流程和机器学习的基本概念。
一 MNIST数据集
MNIST是入门级的计算机视觉数据集,包含了各种手写数字的图片。在这个例子中就是通过机器学习训练一个模型,以识别图片中的数字。
MNIST数据集来自 http://yann.lecun.com/exdb/mnist/
Tensorflow提供了一份python代码用于自动下载安装数据集。Tensorflow官方文档中的url打不开,在CSDN上找到了一个分享:http://download.csdn.net/detail/u010417185/9588647
和官方有点不同的是,我直接把四个数据集下载下来,放在/tmp/mnist下,在项目文件中使用以下代码导入:
- import input_data
- import tensorflow as tf
- mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)
这里的数据集分为两个部分:60000的训练数据集(mnist.train)和10000的测试数据集(mnist.test),测试集的作用是帮助模型泛化。数据对应包含图片和标签,分别用mnist.train.images,mnist.train.lables,mnist.test.images,mnist.test.lables来表示。每张图片有28×28=784个像素点,因此训练图片mnist.train.images的张量表示为 [60000, 784]
,第一个纬度用于索引图片,第二纬度用于索引像素点。由于判断10个数字,这里采用热独,即one-hot-vectors,除了一位数字为1外其他纬度数字为0。例如判断数字为0则其表示为[
1,0,0,0,0,0,0,0,0,0]
。因此训练标签表示为[10000,10
]
,第一纬度索引图片,第二纬度判断数字。
二 softmax回归介绍
softmax模型可以给不同的对象分配概率。根据下图,对输入的x的加权求和,再分别加上一个偏置量,最后输入到softmax函数中:
具体转换为公式,即:
三 实现回归模型
首先进行模型的定义,如下:
- x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
- W = tf.Variable(tf.zeros([784,10])) #初始化权值
- b = tf.Variable(tf.zeros([10])) #初始化偏置值
- y = tf.nn.softmax(tf.matmul(x,W) + b) #根据公式计算
四 训练模型
选用的损失函数为交叉熵,其定义如下:
其中y为预测的概率分布,y'为实际分布。
代码如下:
- y_ = tf.placeholder("float", [None,10]) #表示实际的分布
- cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #计算损失函数
- train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #以梯度下降算法最小化损失函数
- init = tf.initialize_all_variables() #初始化所有变量
- sess = tf.Session() #定义会话
- sess.run(init) #初始化会话
- for i in range(1000): #开始训练,循环训练1000次
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
五 评估模型
选用tf.argmax
函数评估,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)
返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1)
代表正确的标签,用 tf.equal
来检测预测是否与真实标签匹配(索引位置一样表示匹配)。
代码如下:
- correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) #评估
- accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) #将结果转换为浮点数
- print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) #输出
六 代码
- import input_data
- import tensorflow as tf
- mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)
- x = tf.placeholder(tf.float32, [None, 784]) #使用占位符placeholder,第一维度可指定图片的数量是任意的
- W = tf.Variable(tf.zeros([784,10])) #初始化权值
- b = tf.Variable(tf.zeros([10])) #初始化偏置值
- y = tf.nn.softmax(tf.matmul(x,W) + b) #根据公式计算
- y_ = tf.placeholder("float", [None,10]) #表示实际的分布
- cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #计算损失函数
- train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #以梯度下降算法最小化损失函数
- init = tf.initialize_all_variables() #初始化所有变量
- sess = tf.Session() #定义会话
- sess.run(init) #初始化会话
- for i in range(1000): #开始训练,循环训练1000次
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
- correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) #评估
- accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) #将结果转换为浮点数
- print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) #输出
七 实验结果
最终测试结果精确度在91%左右。
Tensorflow学习笔记(一):MNIST机器学习入门的更多相关文章
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- TensorFlow框架(3)之MNIST机器学习入门
1. MNIST数据集 1.1 概述 Tensorflow框架载tensorflow.contrib.learn.python.learn.datasets包中提供多个机器学习的数据集.本节介绍的是M ...
- TensorFlow学习笔记(MNIST报错修正 适用Tensorflow1.3)
在Tensorflow实战Google框架下的深度学习这本书的MNIST的图像识别例子中,每次都要报错 错误如下: Only call `sparse_softmax_cross_entropy_ ...
- tensorflow学习笔记————分类MNIST数据集
在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...
- tensorflow学习笔记(10) mnist格式数据转换为TFrecords
本程序 (1)mnist的图片转换成TFrecords格式 (2) 读取TFrecords格式 # coding:utf-8 # 将MNIST输入数据转化为TFRecord的格式 # http://b ...
- MNIST机器学习入门【学习笔记】
平台信息:PC:ubuntu18.04.i5.anaconda2.cuda9.0.cudnn7.0.5.tensorflow1.10.GTX1060 作者:庄泽彬(欢迎转载,请注明作者) 说明:本文是 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记二:入门基础 好教程 可用
http://www.cnblogs.com/denny402/p/5852083.html tensorflow学习笔记二:入门基础 TensorFlow用张量这种数据结构来表示所有的数据.用一 ...
随机推荐
- NFS(Network File System)服务配置和使用
Sun公司开发NFS (Network File System)之初就是为了在不同linux/Unix系统之间共享文件或者文件夹.可以在本地通过网络挂载远程主机的共享文件,和远程主机交互.NFS共享存 ...
- uva 699 the falling leaves——yhx
aaarticlea/png;base64,iVBORw0KGgoAAAANSUhEUgAAA3QAAAMsCAIAAACTL3d2AAAgAElEQVR4nOx9y7GuPA4tKRADk/92T8 ...
- 【ASP.NET 进阶】无刷新上传图片之一:利用一般处理程序
效果图: 源代码地址:https://github.com/YeXiaoChao/UploadThePic
- 分布式服务框架 Zookeeper(转)
分布式服务框架 Zookeeper -- 管理分布式环境中的数据 Zookeeper 分布式服务框架是 Apache Hadoop 的一个子项目,它主要是用来解决分布式应用中经常遇到的一些数据管理问题 ...
- ajax请求json数据案例
今天有这样一个需求,点击六个大洲,出现对应的一些请求信息,展示在下面,请求请求过后,第二次点击就无需请求.如图所示:点击北美洲下面出现请求的一些数据 html代码结构: <div class=& ...
- UVALive 6257 Chemist's vows --一道题的三种解法(模拟,DFS,DP)
题意:给一个元素周期表的元素符号(114种),再给一个串,问这个串能否有这些元素符号组成(全为小写). 解法1:动态规划 定义:dp[i]表示到 i 这个字符为止,能否有元素周期表里的符号构成. 则有 ...
- POJ 1984 Navigation Nightmare
并查集,给n个点和m条边,每条边有方向和长度,再给q个询问,第i个询问查询两个点之间在Ti时刻时的曼哈顿距离(能连通则输出曼哈顿距离,否则输出-1) 这题跟Corporative Network 有点 ...
- JavaScript Promise API
同步编程通常来说易于调试和维护,然而,异步编程通常能获得更好的性能和更大的灵活性.异步的最大特点是无需等待."Promises"渐渐成为JavaScript里最重要的一部分,大量的 ...
- java 20 - 9 带有缓冲区的字节输出流和字节输入流
由之前字节输入的两个方式,我们可以发现,通过定义数组读取数组的方式比一个个字节读取的方式快得多. 所以,java就专门提供了带有缓冲区的字节类: 缓冲区类(高效类) 写数据:BufferedOutpu ...
- PCTF-2016-WEB
Pctf ** web100 PORT51** 开始看到这个真的无法下手,想过用python–socket编程或者scapy发包.自己觉得是可以的,但是没有去试,后面看一大神writeup,知道: ...