一、TensorFlow简介

TensorFlow是由谷歌开发的一套机器学习的工具,使用方法很简单,只需要输入训练数据位置,设定参数和优化方法等,TensorFlow就可以将优化结果显示出来,节省了很大量的编程时间,TensorFlow的功能很多很强大,这边挑选了一个比较简单实现的方法,就是利用TensorFlow的逻辑回归算法对数据库中的手写数字做识别,让机器找出规律,然后再导入新的数字让机器识别。

二、流程介绍

上图是TensorFlow的流程,可以看到一开始要先将参数初始化,然后导入训练数据,计算偏差,然后修正参数,再导入新的训练数据,不断重复,当数据量越大,理论上参数就会越准确,不过也要注意不可训练过度。

三、导入数据

数据可进入MNIST数据库 (Mixed National Institute of Standards and Technology database),这是一个开放的数据库,里面有许多免费的训练数据可以提供下载,这次我们要下载的是手写的阿拉伯数字,为什么要阿拉伯数字呢?1、因为结果少,只有十个,比较好训练 2、图片的容量小,不占空间,下面是部分的训练数据案例

TensorFlow可以直接下载MNIST上的训练数据,并将它导入使用,下面为导入数据的代码

from tensorflow.examples.tutorials.mnist import input_data
MNIST = input_data.read_data_sets("/data/mnist", one_hot=True)

四、设定参数

接下来就是在TensorFlow里设定逻辑回归的参数,我们知道回归的公式为Y=w*X+b,X为输入,Y为计算结果,w为权重参数,b为修正参数,其中w和b就是我们要训练修正的参数,但训练里要怎么判断计算结果好坏呢?就是要判断计算出来的Y和实际的Y损失值(loss)是多少,并尽量减少loss,这边我们使用softmax函数来计算,softmax函数在计算多类别分类上的表现比较好,有兴趣可以百度一下,这边就不展开说明了,下面为参数设定

X = tf.placeholder(tf.float32, [batch_size, 784], name="image")
Y = tf.placeholder(tf.float32, [batch_size, 10], name="label")

X为输入的图片,图片大小为784K,Y为实际结果,总共有十个结果(数字0-9)

w = tf.Variable(tf.random_normal(shape=[784, 10], stddev=0.01), name="weights")
b = tf.Variable(tf.zeros([1, 10]), name="bias")

w初始值为一个随机的变数,标准差为0.01,b初始值为0。

logits = tf.matmul(X, w) + b
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y)
loss = tf.reduce_sum(entropy)

TensorFlow里面已经有softmax的函数,只要把他叫出来就可以使用。

optimizer =
tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
n_batches = int(MNIST.train.num_examples/batch_size)
for i in range(n_epochs): # train the model n_epochs times
for _ in range(n_batches):
X_batch, Y_batch = MNIST.train.next_batch(batch_size)
sess.run([optimizer, loss], feed_dict={X: X_batch, Y:Y_batch})

接着就是设定优化方式,这边是使用梯度降下发,然后将参数初始化,接着就运行了,这边要提一下,我们的训练方式是每次从训练数据里面抓取一个batch的数据,然后进行计算,这样可以预防过度训练,也比较可以进行事后的验证,运行完后再用下面的代码进行验证

    n_batches = int(MNIST.test.num_examples/batch_size)
total_correct_preds = 0
for i in range(n_batches):
X_batch, Y_batch = MNIST.test.next_batch(batch_size)
_, loss_batch, logits_batch = sess.run([optimizer, loss, logits],
feed_dict={X: X_batch, Y:Y_batch})
preds = tf.nn.softmax(logits_batch)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
total_correct_preds += sess.run(accuracy)
print ("Accuracy {0}".format(total_correct_preds/MNIST.test.num_examples))

最后shell跑出来的结果是0.916,虽然看上去还算是不错的结果,但其实准确率是很低的,因为他验证的方式是判断一个图片是否为某个数字(单输出),所以假如机器随便猜也会有0.82左右的命中几率(0.9*0.9+0.1*0.1),想要更准确的话目前想到有两个方向,一个是提高训练量和增加神经网络的层数。

用TensorFlow做图像识别(python)的更多相关文章

  1. 03 使用Tensorflow做计算题

    我们使用Tensorflow,计算((a+b)*c)^2/a,然后求平方根.看代码: import tensorflow as tf # 输入储存容器 a = tf.placeholder(tf.fl ...

  2. 【转载】 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,

    原文地址: https://blog.csdn.net/weixin_40759186/article/details/87547795 ------------------------------- ...

  3. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  4. Tensorflow学习(练习)—使用inception做图像识别

    import osimport tensorflow as tfimport numpy as npimport re from PIL import Imageimport matplotlib.p ...

  5. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. [树莓派(raspberry pi)] 02、PI3安装openCV开发环境做图像识别(详细版)

    前言 上一篇我们讲了在linux环境下给树莓派安装系统及入门各种资料 ,今天我们更进一步,尝试在PI3上安装openCV开发环境. 博主在做的过程中主要参考一个国外小哥的文章(见最后链接1),不过其教 ...

  7. Tensorflow做阅读理解与完形填空

    catalogue . 前言 . 使用的数据集 . 数据预处理 . 训练 . 测试模型运行结果: 进行实际完形填空 0. 前言 开始写这篇文章的时候是晚上12点,突然想到几点新的理解,赶紧记下来.我们 ...

  8. tensorflow 学习1——tensorflow 做线性回归

    . 首先 Numpy: Numpy是Python的科学计算库,提供矩阵运算. 想想list已经提供了矩阵的形式,为啥要用Numpy,因为numpy提供了更多的函数. 使用numpy,首先要导入nump ...

  9. TensorFlow 安装以及python虚拟环境

    python虚拟环境 由于TensorFlow只支持某些版本的python解释器,如Python3.6.如果其他版本用户要使用TensorFlow就必须安装受支持的python版本.为了方便在不同项目 ...

随机推荐

  1. Uniapp使用iconfont

    看别人的项目有各种各样的图标既好看占用内存还小 后来才知道原来有icon图标这个东西,原谅我真的一直处于混沌的状态. 刚好最近项目使用了uniapp框架,引入iconfont的方式和之前有些不太一样 ...

  2. python基础之包的导入

    包的导入 python是一门灵活性的语言 ,也可以说python是一门胶水语言,顾名思义,就是可一导入各类的包, python的包可是说是所有语言中最多的.当然导入包大部分是为了更方便,更简便,效率更 ...

  3. ubuntu netstat 查看端口占用情况

    netstat 用于显示各种网络相关信息,如网络连接,路由表,接口状态 (Interface Statistics),masquerade 连接,多播成员 (Multicast Memberships ...

  4. Java多线程遍历文件夹,广度遍历加多线程加深度遍历结合

    复习IO操作,突然想写一个小工具,统计一下电脑里面的Java代码量还有注释率,最开始随手写了一个递归算法,遍历文件夹,比较简单,而且代码层次清晰,相对易于理解,代码如下:(完整代码贴在最后面,前面是功 ...

  5. Logback设置SQL参数打印

    一.hibernate中设置SQL参数打印: (主要是第一句) <logger name="org.hibernate.type.descriptor.sql.BasicBinder& ...

  6. 2005年NOIP普及组复赛题解

    题目涉及算法: 陶陶摘苹果:入门题: 校门外的树:简单模拟: 采药:01背包: 循环:模拟.高精度. 陶陶摘苹果 题目链接:https://www.luogu.org/problem/P1046 循环 ...

  7. Java Integer类的缓存

    首先看一段代码(使用JDK 5),如下: public class Hello { public static void main(String[] args) { int a = 1000, b = ...

  8. Python--day23--类的命名空间

    当创建一个对象时,就会在内存中分出一块新的空间存放这个对象的属性,这块空间也叫类的命名空间.里面存放着类对象指针可以找到类.

  9. H3C 高级ACL部署位置示例

  10. Codeforces Round #186 (Div. 2)

    A. Ilya and Bank Account 模拟. B. Ilya and Queries 前缀和. C. Ilya and Matrix 考虑每个元素的贡献. 边长为\(2^n\)时,贡献为最 ...