LSTM用于MNIST手写数字图片分类
按照惯例,先放代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) # 输入图片是28*28
n_inputs = 28 #输入一行,一行有28个数据
max_time = 28 #一共28行
lstm_size = 100 #隐层单元
n_classes = 10 # 10个分类
batch_size = 50 #每批次50个样本
n_batch = mnist.train.num_examples // batch_size #计算一共有多少个批次 #这里的none表示第一个维度可以是任意的长度
x = tf.placeholder(tf.float32,[None,784])
#正确的标签
y = tf.placeholder(tf.float32,[None,10]) #初始化权值
weights = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
#初始化偏置值
biases = tf.Variable(tf.constant(0.1, shape=[n_classes])) #定义RNN网络
def RNN(X,weights,biases):
# inputs=[batch_size, max_time, n_inputs]
inputs = tf.reshape(X,[-1,max_time,n_inputs])
#定义LSTM基本CELL
lstm_cell = tf.nn.rnn_cell.LSTMCell(lstm_size)
#lstm_cell = tf.contrib.rnn.LSTMCell(lstm_size, name='basic_lstm_cell')
#final_state[state, batch_size, cell.state_size]
# final_state[0]是cell state
# final_state[1]是hidden_state
#outputs: The RNN output 'Tensor'
#If time_major == False(default), this will be a 'Tensor' shaped:
# [batch_size, max_time, cell.output_size]
#If time_major == True, this will be a 'Tensor' shaped:
# [max_time, batch_size, cell.output_size]
outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
results = tf.nn.softmax(tf.matmul(final_state[1],weights) + biases)
return results #计算RNN的返回结果
prediction= RNN(x, weights, biases)
#损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#使用AdamOptimizer进行优化
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#把correct_prediction变为float32类型
#初始化
init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
for epoch in range(50):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print ("Iter " + str(epoch) + ", Testing Accuracy= " + str(acc))
LSTM用于MNIST手写数字图片分类的更多相关文章
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- MNIST手写数字分类simple版(03-2)
simple版本nn模型 训练手写数字处理 MNIST_data数据 百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...
随机推荐
- docker安装配置mongodb
1 执行 docker search mongo 命令: 2 运行mongo docker run --name mongo -v /mnt/mongodb:/data/db -p 27017:270 ...
- python大佬养成计划----HTML网页设计(表格)
制作网页时,要合理规划网页布局.比如,在网页中添加一个表格,可分为上.中.下三部分,上部存放网页标题或LOGO图片,中间部分是整个网页的主体内容,底部就是相关制作信息.此外,单元格里还可再添加单元格, ...
- gdb命令小结
GDB命令小结 gdb <filename> : 调试指定程序文件 r : run 的简写,运行被调试程序, 如果此前没有下过断点,则执行完整个程序:如果有断点, 则程序暂停在第一个可用断 ...
- 微信小程序-饮食日志_开发记录01
今天主要了解微信小程序的框架结构以及环境部署等. 小程序的框架主要分为: js.json.wxss.wxml等 和java web的内容相似,主要了解内部代码的使用情况和语言方式. 主要写了页面的框架 ...
- 【Luogu5294】[HNOI2019]序列
题目链接 题意 给定一个序列,要求将它改造成一个非降序列,修改一个数的代价为其改变量的平方. 最小化总代价. 另有\(Q\) 次询问,每次修改一个位置上的数.(询问之间独立,互不影响) Sol 神仙 ...
- Gym-100923L-Por Costel and the Semipalindromes(进制转换,数学)
链接: https://vjudge.net/problem/Gym-100923L 题意: Por Costel the pig, our programmer in-training, has r ...
- 对ECMAScript的研究-----------引用
ECMAScript 新特性与标准提案 一:ES 模块 第一个要介绍的 ES 模块,由于历史上 JavaScript 没有提供模块系统,在远古时期我们常用多个 script 标签将代码进行人工隔离.但 ...
- java总结1
栈,堆,方法区.main和局部变量在栈,new 对象 在堆, 类和常量在方法区除了8大基础数据类型,其他都为引用变量局部变量在函数内或方法上声明,没有默认值,定义必须赋值一旦提供构造方法,就不会有默认 ...
- c++ STL map 用法
map是STL的一个关联容器,它提供一对一(其中第一个可以称为关键字,每个关键字只能在map中出现一次,第二个可能称为该关键字的值)的数据 处理能力,由于这个特性,它完成有可能在我们处理一对一数据的时 ...
- Tree and Permutation
Tree and Permutation 给出一个1,2,3...N的排列,显然全部共有N!种排列,每种排列的数字代表树上的一个结点,设Pi是其中第i种排列的相邻数字表示的结点的距离之和,让我们求su ...