mnist手写数字检测
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 23 06:16:04 2019
@author: 92958
"""
import numpy as np
import tensorflow as tf
#下载并载入mnist(55000*28*28图片)
#from tensorflow.examples.tutorials.mnist import input_data
#创造变量mnist,用特定函数,接收
mnist = input_data.read_data_sets('F:\\python\\TensorFlow\\mnist\\mnist_data\\',one_hot=True)
#one_hot独热码,例,:0001000000
#None表示tensor的第一个维度可以是任何长度
input_x = tf.placeholder(tf.float32,[None,28*28])/255. #除255表示255个灰度值
output_y = tf.placeholder(tf.int32,[None, 10]) #10个输出标签
input_x_images = tf.reshape(input_x, [-1,28,28,1]) #改变形状之后的输出
#从Test选3000个数据
test_x = mnist.test.images[:3000]#图片
test_y = mnist.test.labels[:3000]#标签
#日志
path = "F:\\python\\TensorFlow\\mnist\\log"
#构建第一层神经网络
conv1 = tf.layers.conv2d(
inputs=input_x_images, #形状28.28.1
filters =32, #32个过滤器输出深度32
kernel_size=[5,5], #过滤器在二维的大小5*5
strides=1, #步长为1
padding='same', #same表示输出大小不变,因此外围补零两圈
activation=tf.nn.relu #激活函数为relu
)
#输出得到28*28*32
#第一层池化层pooling(亚采样)
pool1 = tf.layers.max_pooling2d(
inputs=conv1, #形状为28*28*32
pool_size=[2,2], #过滤器大小2*2
strides=2, #步长为2
)
#形状14*14*32
#第二层卷积层
conv2 = tf.layers.conv2d(
inputs=pool1, #形状14*14*32
filters =64, #32个过滤器输出深度64
kernel_size=[5,5], #过滤器在二维的大小5*5
strides=1, #步长为1
padding='same', #same表示输出大小不变,因此外围补零两圈
activation=tf.nn.relu #激活函数为relu
)
#形状14*14*64
#第二层池化层pooling(亚采样)
pool2 = tf.layers.max_pooling2d(
inputs=conv2, #形状为14*14*64
pool_size=[2,2], #过滤器大小2*2
strides=2, #步长为2
)
#形状7*7*64
#平坦化(flat)
flat = tf.reshape(pool2,[-1,7*7*64]) #形状7*7*64
#全连接层
dense = tf.layers.dense(inputs = flat,
units=1024, #有1024个神经元
activation=tf.nn.relu#激活函数relu
)
#dropout:丢弃50%,rate=0.5
dropout = tf.layers.dropout(inputs=dense, rate=0.5)
#10个神经元的全连接层,这里不用激活函数来做非线性化
logits=tf.layers.dense(inputs=dropout,units=10)#输出1*1*10
#计算误差,(计算cross entropy(交叉熵),再用softmax计算百分比概率)
loss = tf.losses.softmax_cross_entropy(onehot_labels=output_y,
logits=logits)
#Adam优化器来最小化误差
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
#精度
#返回
accuracy = tf.metrics.accuracy(
labels=tf.argmax(output_y,axis=1),
predictions=tf.argmax(logits,axis=1),)[1]
#创建会话
sess = tf.Session()
#初始化变量全局和局部
init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init)
writer =tf.summary.FileWriter(path,sess.graph)
for i in range(1000):
batch = mnist.train.next_batch(50)
#从train数据集里取下一个50个样本
train_loss,train_op_= sess.run([loss,train_op],
{input_x:batch[0],output_y:batch[1]})
if i%100==0:
test_accuracy = sess.run(accuracy,
{input_x:test_x,output_y:test_y})
print("Step=",i)
print("Train loss=",train_loss)
print("Test accuracy=",test_accuracy)
#测试
test_output=sess.run(logits,{input_x:test_x[:20]})
inferenced_y=np.argmax(test_output,1)
print(inferenced_y,'推测')
print(np.argmax(test_y[:20],1),'真实')
mnist数据集http://yann.lecun.com/exdb/mnist/
mnist手写数字检测的更多相关文章
- 学习OpenCV——SVM 手写数字检测
转自http://blog.csdn.net/firefight/article/details/6452188 是MNIST手写数字图片库:http://code.google.com/p/supp ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- 简单HOG+SVM mnist手写数字分类
使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images- ...
- mnist 手写数字识别
mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
随机推荐
- OAuth 2.0授权之授权码授权
OAuth 2.0 是一个开放的标准协议,允许应用程序访问其它应用的用户授权的数据(如用户名.头像.昵称等).比如使用微信.QQ.支付宝登录等第三方网站,只需要用户点击授权按钮,第三方网站就会获取到用 ...
- js设计模式总结1
js设计模式有很多种,知道不代表会用,更不代表理解,为了更好的理解每个设计模式,对每个设计模式进行总结,以后只要看到总结,就能知道该设计模式的作用,以及模式存在的优缺点,使用范围. 本文主要参考张容铭 ...
- java权限控制以及变量的初始化
知识是靠积累的,不断的温习会帮你让你遇到许多问题,解决完这些问题之后,会收获许多,233333333333333. 1.java访问控制符 2.java变量初始化问题 默认构造方法的名字与类名相同,它 ...
- 并发编程 —— Timer 源码分析
前言 在平时的开发中,肯定需要使用定时任务,而 Java 1.3 版本提供了一个 java.util.Timer 定时任务类.今天一起来看看这个类. 1.API 介绍 Timer 相关的有 3 个类: ...
- 在ASP.NET MVC中使用Grid.mvc
很久没有写ASP.NET的博文了,专心工作嘛,今天写一点MVC的博文,也是自己练习来的,是使用grid.mvc来显示数据. 首先打开Manage Nuget Packages,搜索grid.mvc并安 ...
- 控制器中获取Field值
在ASP.NET MVC程序中,我们需要POST Data到制器中,是有很多方法.但是我们想在控制器中,获取Feild值呢?怎样获取?你可以留意到有一个类FormCollection.它能帮助到我们解 ...
- 关于ASPxComboBox通过ClientInstanceName,js获取不到控件的问题
今天突然遇到一个很奇葩的问题 ASPxComboBox中设置了ClientInstanceName.但是通过cmbOrganization.GetValue()获取不到值. 报错cmbOrganiza ...
- NodeJS+Express开发web,为什么中文显示为乱码
把你的文件另存为下,格式为utf-8的试下就行!
- 如何让win2008服务器显示中文无乱码
使用Windows Server 2008 R2 IIS搭建FTP服务器时,客户端登录FTP后中文文件夹显示为乱码,应在“控制面板”-“区域和语言”中查看“当前系统区域设置”的情况. 应确保“非Uni ...
- 【Linux】rpm常用命令及rpm参数介绍
RPM是RedhatPackageManager的缩写,是由RedHat公司开发的软件包安装和管理程序,同Windows平台上的Uninstaller比较类似.使用RPM,用户可以自行安装和管理Lin ...