MNIST:一个由60000行训练数据集和10000行的测试数据集(机器学习模型设计时必须有一个单独的数据集用于评估模型的性能)组成的数据集。

下载mnist的数据集后,将文件放入C:\Users\missouter.keras\datasets中,再将以下代码写入程序:

  1. import tensorflow as tf
  2. import numpy as np
  3. def transform_data(x_train, y_train, x_test, y_test):
  4. x_train_new=x_train.reshape(-,*)
  5. x_test_new=x_test.shape(-, *)
  6. y_train_new=np.zeros((y_train.shape[],))
  7. y_test_new=np.zeros((y_test.shape[],))
  8. for i in range(y_train.shape[]):
  9. y_train_new[i][y_train[i]] =
  10. for i in range(y_test.shape[]):
  11. y_test_new[i][y_test[i]] =
  12. return x_train_new,y_train_new,x_test_new,y_test_new
  13. mnist=tf.keras.datasets.mnist
  14. (x_train, y_train), (x_test, y_test)=mnist.load_data()
  15. x_train,x_test=x_train/255.0,x_test/255.0
  16. x_train, y_train, x_test, y_test = transform_data(x_train, y_train, x_test, y_test)

这样就直接完成了数据集的导入与训练、测试集的分割。处理完后,还需定义每轮feed给网络的数量:

  1. BATCH_SIZE = 100
  2. TOTAL_SIZE = x_train.shape[0]

下一步需要使用softmax回归(softmax regression),softmax模型可以用来给不同的对象分配概率,以得到某个图像为某个特定数字类的证据为例,输入图片x代表的数字为x的证据可表示为:

  1. evidencei=∑Wijxj + bi

其中Wi代表权重;因输入会带一些无关的干扰量,故加入偏置量bi,j代表给定图片的像素索引。

运用softmax函数可将证据转换成概率y:

  1. y=softmaxevidence 激励 / 链接函数

从而将我们定义的线性函数输出转换成十个数字类的概率分布;一般将softmax模型函数定义为下列的前一种形式,即将输入值当成幂指数求值,再将结果值正则化。

  1. softmaxx)=normalizeexpx))
  2. softmaxxi=expxi)/∑expxj

更为紧凑的格式可写为:

  1. y=softmaxWx+b

用以下方式创建可交换单元:

  1. x=tf.placeholder("float", [None, 784])

placeholder构筑整个系统的graph;第一个参数是数据类型,第二个参数是数据形状,上述代码中[None, 784]表明行未定,列为784列;第三个参数为名称,未写入代码。

上述语句的x是一个占位符;MNIST的一张图包含28*28=784个像素,每一张图展平成784维的向量;none 表示此张量的第一维度可以是任何一个长度。

Variable表示一个可修改的向量,放置模型参数。可用于计算输入值,也可在计算中被修改。下方代码中,因我们需要让784维的图片向量乘权值w得到十位的证据向量,故模型参数形状为[784,10]。

用全为零的张量初始化w、b:

  1. w=tf.Variable(tf.zeros([784,10]))
  2. b=tf.Variable(tf.zeros[10])

我们需要用784维的图片向量乘w以获得10维的证据值向量,故w维度wei[784, 10](矩阵乘法)。

接着用以下语句实现模型:

  1. y=yf.nn.softmax(tf.matmul(x, W)+b)

在训练模型时,我们定义成本\损失来评估模型好坏;交叉熵是一个常见的成本函数:

  1. Hy‘(y)=-∑ylog(yi)

y’为预测分布,y为实际分布;

使用一个占位符输入正确值,然后计算交叉熵:

  1. y_=tf.placeholder("float", [none, 10])
  2. cross_entropy=-tf.reduce_sum(y_*tf.log(y))

选择梯度算法以0.01的学习速率最小化交叉熵,将每个变量一点点朝成本降低的方向移动:

  1. train_step = tf.train.GradienDescentOptimizer(0.01).minimize(cross_entropy)

简单调整一行代码就可以使用tensorflow提供的其他算法。

最后,初始化模型并训练模型,例子中训练了一千次:

  1. sess=tf.Session()
  2. sess.run(init)
  3. for i in range(1000):
  4. start = (i * BATCH_SIZE) % TOTAL_SIZE
  5. end = start + BATCH_SIZE
  6. batch_xs, batch_ys = x_train[start:end], y_train[start:end]
  7. sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

每次循环随机抓取训练集中的100个批处理点,替换先前定位的占位符,运行train_step。

接下来使用tf.argmax给出tensor对象某维上的数据最大值的索引值,则可知:

  1. tf.argmax(y, 1)返回的是预测标签值
  2. tf.argmax(y_, 1)代表的是正确的标签

使用tf.equal检测预测是否与真实标签匹配:

  1. correct_prediction=tf.equal(tf.argmax(y, 1), tf.argmax(y_,1))

代码返回布尔值,将其转换成浮点数后取平均,计算学习模型正确率:

  1. accuracy=tf.reduce_mean(tf.cast(correct_prediction, "float"))
  2. print(sess.run(accuracy, feed_dict={x:x_test, y_: y_test}))

这样对于mnist手写体识别的过程就写完了。

机器学习的hello world——MNIST的更多相关文章

  1. TensorFlow 学习(3)——MNIST机器学习入门

    通过对MNIST的学习,对TensorFlow和机器学习快速上手. MNIST:手写数字识别数据集 MNIST数据集 60000行的训练数据集 和 10000行测试集 每张图片是一个28*28的像素图 ...

  2. TensorFlow下利用MNIST训练模型并识别自己手写的数字

    最近一直在学习李宏毅老师的机器学习视频教程,学到和神经网络那一块知识的时候,我觉得单纯的学习理论知识过于枯燥,就想着自己动手实现一些简单的Demo,毕竟实践是检验真理的唯一标准!!!但是网上很多的与t ...

  3. Python资料汇总(建议收藏)

    整理汇总,内容包括长期必备.入门教程.练手项目.学习视频. 一.长期必备. 1. StackOverflow,是疑难解答.bug排除必备网站,任何编程问题请第一时间到此网站查找. https://st ...

  4. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  5. Fashion-MNIST:A MNIST-like fashion product database. Benchmark

    Zalando的文章图像的一个数据集包括一个训练集6万个例子和一个10,000个例子的测试集. 每个示例是一个28x28灰度图像,与10个类别的标签相关联. 时尚MNIST旨在作为用于基准机器学习算法 ...

  6. Python入门、练手、视频资源汇总,拿走别客气!

    摘要:为方便朋友,重新整理汇总,内容包括长期必备.入门教程.练手项目.学习视频. 一.长期必备. 1. StackOverflow,是疑难解答.bug排除必备网站,任何编程问题请第一时间到此网站查找. ...

  7. 自学Python,新手上路,好资源免费分享

    Python 可以用来做什么? 在我看来,基本上可以不负责任地认为,Python 可以做任何事情.无论是从入门级选手到专业级选手都在做的爬虫,还是Web 程序开发.桌面程序开发还是科学计算.图像处理, ...

  8. ---转载---phython资料

    整理汇总,内容包括长期必备.入门教程.练手项目.学习视频. 一.长期必备. 1. StackOverflow,是疑难解答.bug排除必备网站,任何编程问题请第一时间到此网站查找. https://st ...

  9. Python学习教程(一)自学资源分享

    Python 可以用来做什么? 在我看来,基本上可以不负责任地认为,Python 可以做任何事情.无论是从入门级选手到专业级选手都在做的爬虫,还是Web 程序开发.桌面程序开发还是科学计算.图像处理, ...

随机推荐

  1. (一)C# Windows Mobile 半透明窗体

    Windows Mobile,个人心中臻至完美的系统. 不忍自己对WM的钻研成果消逝,故留作纪念. 系列开篇,便是一个曾令自己困扰很久的问题:如何实现半透明窗体. 如果了解Win32编程,其实很简单. ...

  2. Epicor RoHS Overview

    Epicor ERP具有一个旨在帮助符合指令2002/95/EC (RoHS1) and 2011/65/EU (RoHS2)的模块,特别适用于医疗设备公司. 不合格的依据是–最大浓度值和合格声明/ ...

  3. 关于DNS解析:侧面剖析

    作为一个合格的重度windows使用用户,我清楚的知道一个文件——hosts文件:C:\Windows\System32\drivers\etc\hosts文件 该文件需要一定的管理员权限. 这个文件 ...

  4. [函数] PHP取二进制文件头快速判断文件类型

    一般我们都是按照文件扩展名来判断文件类型,但其实不太靠谱,因为可以通过修改扩展名来伪装文件类型.其实我们可以通过读取文件信息来识别,比如 PHP扩展中提供了类似 exif_imagetype 这样的函 ...

  5. 日志分析工具ELK(四)

    Logstash收集TCP日志 #Input plugins TCP插件 所需的配置选项 tcp { port =>... } [root@linux-node1 ~]# cat tcp.con ...

  6. vs code 打开文件时,取消文件目录的自动定位跟踪

    文件-->首选项-->设置-->在搜索栏中搜索:explorer.autoReveal;    去掉勾选即可.

  7. 【三剑客】sed命令

    1. Sed 简介 sed 是Stream Editor(流编辑器)的缩写,是操作.过滤和转换文本内容的强大工具.常用功能有增删改查,过滤,取行.   sed 是一种新型的,非交互式的编辑器. 它能执 ...

  8. Alpine Linux 3.9.2 发布,轻量级 Linux 发行版

    开发四年只会写业务代码,分布式高并发都不会还做程序员?   Alpine Linux 3.9.2 已发布,Alpine Linux 是一款面向安全的轻量级 Linux 发行版,体积十分的小. Alpi ...

  9. 使用Hexo框架搭建博客,并部署到github上

    开发背景:年后回来公司业务不忙,闲暇时间了解一下node的使用场景,一篇文章吸引了我15个Nodejs应用场景,然后就被这个hexo框架吸引了,说时迟,那时快,赶紧动手搭建起来,网上找了好多资料一天时 ...

  10. Golang项目部署

    文章来源:https://goframe.org/deploymen... 一.独立部署 使用GF开发的应用程序可以独立地部署到服务器上,设置为后台守护进程运行即可.这种模式常用在简单的API服务项目 ...