LSTM是RNN的一种算法, 在序列分类中比较有用。常用于语音识别,文字处理(NLP)等领域。

等同于VGG等CNN模型在在图像识别领域的位置。  本篇文章是叙述LSTM 在MNIST 手写图中的使用。

用来给初步学习RNN的一个范例,便于学习和理解LSTM .

先把工作流程图贴一下

代码片段

数据准备

  1. def makedata():
  2. img_rows, img_cols = 28, 28
  3.  
  4. mnist = fetch_mldata("MNIST original")
  5. # rescale the data, use the traditional train/test split
  6. X_1D, y_int = mnist.data / 255., mnist.target
  7. y = np_utils.to_categorical(y_int, num_classes=10)
  8.  
  9. X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols )
  10.  
  11. input_shape = (img_rows, img_cols, 1)
  12. x_train, x_test = X[:60000], X[60000:]
  13. y_train, y_test = y[:60000], y[60000:]
  14.  
  15. return X, y
  16. pass

下载 MNIST数据, 进行归一化  mnist.data / 255, 把数据[7000,784 ] 转成[ 70000,28,28]

构建模型:

  1. def buildlstm():
  2.  
  3. import numpy as np
  4.  
  5. data_dim = 28
  6. timesteps = 28
  7. num_classes = 10
  8.  
  9. # expected input data shape: (batch_size, timesteps, data_dim)
  10. model = Sequential()
  11. model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim+14)))
  12. model.add(LSTM(32, return_sequences=True))
  13. model.add(LSTM(32))
  14. model.add(Dense(10, activation='softmax'))
  15.  
  16. model.compile(loss='categorical_crossentropy',
  17. optimizer='rmsprop',
  18. metrics=['accuracy'])
  19. print model.summary()
  20. return model
  21. pass

基础参数: data_dim, timesteps, num_classes   分别为 28,28, 10
网络层级 :    LSTM ----》LSTM ----》LSTM ----》Dense
注意点: input_shape=(timesteps, data_dim+14))   此处 应该为  data_dim , data_dim+14是我做第二个试验使用。
网络理解: RNN是用前一部分数据对当前数据的影响,并共同作用于最后结果。 用基础的深度神经网络(只有Dense层),是把MNIST一个图形,
提取成784个像素数据,把784个数据扔给神经网络,784个数据是同等的概念。 训练出权重来确定最终的分类值。

RNN 之于MNIST, 是把MNIST 分成 28x28 数据。可以理解为用一个激光扫描一个图片,扫成28个(行)数据, 每行为28个像素。 站在时间序列
的角度,其实图片没有序列概念。但是我们可以这样理解, 每一行于下一行是有位置关系的,不能进行顺序变化。 比如一个手写 “7”字, 如果把28行
的上下行顺序打乱, 那么7 上面的一横就可能在中间位置,也可能在下面的位置。  这样,最终的结果就不应该是 7 .  
所以MNIST 的 28x28可以理解为 有时序关系的数据。

训练预测:

  1. def runTrain(model, x_train, x_test, y_train, y_test):
  2. model.fit(x_train, y_train, batch_size= nbatch_size, epochs= nEpoches)
  3. score = model.evaluate(x_test, y_test, batch_size=nbatch_size)
  4. print 'evaluate score:', score
  5. pass

这部分应该没什么好说的

主程序:

  1. def test():
  2.  
  3. X,y = makedata2()
  4. x_train, x_test = X[:60000], X[60000:]
  5. y_train, y_test = y[:60000], y[60000:]
  6. model = buildlstm()
  7. runTrain(model, x_train, x_test, y_train, y_test )
  8. pass

运行结果

  1. 结构:
  2. Layer (type) Output Shape Param #
  3. =================================================================
  4. lstm_1 (LSTM) (None, 28, 32) 7808
  5. _________________________________________________________________
  6. lstm_2 (LSTM) (None, 28, 32) 8320
  7. _________________________________________________________________
  8. lstm_3 (LSTM) (None, 32) 8320
  9. _________________________________________________________________
  10. dense_1 (Dense) (None, 10) 330
  11. =================================================================
  12. Total params: 24,778
  13. Trainable params: 24,778
  14. Non-trainable params: 0
  15. _________________________________________________________________
  16.  
  17. 结果:
  18. base lstm for mnist
  19. acc : 98.56%
  20.  
  21. 结果2
  22. 把数据最后增加 50% 0 (dim X 0.5)
  23. acc : 98.39%
  24. 结果基本上 与原数据一致

该实验证明两个结论:
1.  LSTM可用于图形识别
2.  在数据中 每行28个基础像素后面 + 14 个空白(0)的元素,不影分类识别。

写在最后:  本实验的目的是为了理解RNN(LSTM),  只有理解了才能很好的使用。 本文章的目的是为记录和分享。
再说下 RNN在其它领域的应用。  比如在语音识别领域,一个音谱,识别成一个单词(词语),可以理解成一个
竖向扫描的MNIST ,   一个股票的K线图,也可以理解一个竖向扫描的MNIST。  还有其它领域,可以归纳递推。 
入门之后, 如何在自己的领域,再深入(构建复杂模型,优化数据的处理),提高网络模型的识别准确,那需要
见仁见智的。

代码文件链接:

源码下载

有对 金融程序化 和 深度学习结合有兴趣的可以加群 , 个人群: 杭州程序化交易群  375129936

用LSTM分类 MNIST的更多相关文章

  1. NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)

    用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...

  2. tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...

  3. 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集

    原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...

  4. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  5. LSTM用于MNIST手写数字图片分类

    按照惯例,先放代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 ...

  6. 检测用户命令序列异常——使用LSTM分类算法【使用朴素贝叶斯,类似垃圾邮件分类的做法也可以,将命令序列看成是垃圾邮件】

    通过 搜集 Linux 服务器 的 bash 操作 日志, 通过 训练 识别 出 特定 用户 的 操作 习惯, 然后 进一步 识别 出 异常 操作 行为. 使用 SEA 数据 集 涵盖 70 多个 U ...

  7. 分类-MNIST(手写数字识别)

    这是学习<Hands-On Machine Learning with Scikit-Learn and TensorFlow>的笔记,如果此笔记对该书有侵权内容,请联系我,将其删除. 这 ...

  8. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  9. TensorFlow技术解析与实战学习笔记(15)-----MNIST识别(LSTM)

    一.任务:采用基本的LSTM识别MNIST图片,将其分类成10个数字. 为了使用RNN来分类图片,将每张图片的行看成一个像素序列,因为MNIST图片的大小是28*28像素,所以我们把每一个图像样本看成 ...

随机推荐

  1. 获取AJAX加载的内容

    1.有些网页内容使用AJAX加载,AJAX一般返回的是JSON,直接对AJAX地址进行post或get,就返回JSON数据了. 2.用抓包工具分析https://movie.douban.com/j/ ...

  2. 爬取豆瓣电影储存到数据库MONGDB中以及反反爬虫

    1.代码如下: doubanmoive.py # -*- coding: utf-8 -*- import scrapy from douban.items import DoubanItem cla ...

  3. 导入maven项目时出现 Version of Spring Facet could not be detected. 解决方法

    问题出现在: 导入maven项目的时候,其中,我的这个maven项目是由Spring,Struts2,Mybatis搭建的. 问题截图:  即Spring的版本不能被检测到.此时需要做的就是找到spr ...

  4. 项目(1)----用户信息管理系统(5)---(剩余jsp界面)

    完成剩余jsp界面 首页界面前面我写了,接下来还有就是一个显示所有用户界面 1:注册界面 2:显示所有用户信息界面 1:注册界面 <%@ page language="java&quo ...

  5. Spark配置参数优先级

    1.Properties set directly on the SparkConf take highest precedence, 2.then flags passed to spark-sub ...

  6. 教程:安装禅道zentao项目管理软件github上的开发版

    该文章转自:吕滔博客 直接从github拉下来的禅道的源码,是跑不起来的.除非你按我的教程来做...哈哈哈(不要脸)~~~~ 禅道官网提供的版本包是带了有安装文件,并有打包合成一些css.js文件的. ...

  7. 【java设计模式】【行为模式Behavioral Pattern】迭代器模式Iterator Pattern

    package com.tn.pattern; public class Client { public static void main(String[] args) { Object[] objs ...

  8. Principle-初步认识(简介)

    Principle官网 探究了一下 . 呃--作出了下边这玩意 做的好的是这样的,瞬间把自己给菜了,给大家看看,设计需要UI功夫啊 把这个用上你的界面就搞基了,图形在水平.垂直上的动态效果(*.*) ...

  9. iOS通用链接(Universal Links)突然点击无效的解决方案

    接上文<微信中通过页面(H5)直接打开本地app的解决方案>已经把iOS搞定并且已经正常能跑了,突然就再也用不了了... 问题描述 测试告诉我,如果从微信打开App之后,点击App右上角的 ...

  10. 尝试在条件“$(_DeviceSdkVersion) >= 21”中对计算结果为“”而不是数字的“$(_DeviceSdkVersion)

    晚上搞xamarin ,运行xamarin项目好好的,不知道怎么回事,一次运行xamarin android项目的时候,部署失败,以前也是遇到这样的错误. 尝试在条件"$(_DeviceSd ...