# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense
from keras.optimizers import Adam TIME_STEPS = 28 #图片的高
INPUT_SIZE = 28 #图片的行
BATCH_SIZE = 50 #每批训练多少图片
BATCH_INDEX = 0
OUTPUT_SIZE = 10
CELL_SIZE = 50
LR = 0.001 #下载mnist数据集
# X shape (60000,28*28) ,y shape (10000)
(X_train,y_train),(X_test,y_test) = mnist.load_data() # 数据预处理
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10) # 建模型
model = Sequential()
# RNN
model.add(SimpleRNN(
batch_input_shape=(None,TIME_STEPS,INPUT_SIZE),# 每次训练的量(None表示全部),图片大小
output_dim=CELL_SIZE,
))
# 输出层
model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax')) # 优化器
adam = Adam(LR)
model.compile(optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']) # 训练
for step in range(4001):
X_batch=X_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:,:]
Y_batch=y_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:]
cost = model.train_on_batch(X_batch,Y_batch) BATCH_INDEX += BATCH_SIZE
BATCH_INDEX = 0 if BATCH_INDEX>=X_train.shape[0] else BATCH_INDEX if step % 500 == 0:
cost,accuracy = model.evaluate(X_test,y_test,batch_size=y_test.shape[0],verbose=False)
print('test cost: ',cost,'test accuracy: ',accuracy)

用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)的更多相关文章

  1. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  2. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  3. 吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  4. 用Keras搭建神经网络 简单模版(二)——Classifier分类(手写数字识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  5. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  6. RNN探索(2)之手写数字识别

    这篇博文不介绍基础的RNN理论知识,只是初步探索如何使用Tensorflow,之后会用笔推导RNN的公式和理论,现在时间紧迫所以先使用为主~~ import numpy as np import te ...

  7. 用tensorflow求手写数字的识别准确率 (简单版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = in ...

  8. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 100天搞定机器学习|day39 Tensorflow Keras手写数字识别

    提示:建议先看day36-38的内容 TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edge ...

随机推荐

  1. 关于阿里云OSS上传图片之后会被旋转90度的解决办法

    原文:https://www.cnblogs.com/wuhjbk/p/10133596.html 问题描述:正常的图片前端上传到oss成功之后的资源地址.在html上引用的时候被旋转了90度oss资 ...

  2. 用js刷剑指offer(变态跳台阶)

    一只青蛙一次可以跳上1级台阶,也可以跳上2级……它也可以跳上n级.求该青蛙跳上一个n级的台阶总共有多少种跳法. 牛客网链接 思路 假设青蛙跳上一个n级的台阶总共有f(n)种跳法. 现在青蛙从第n个台阶 ...

  3. tp5 左连接

    db('detainform')->alias('d')->join("information i",'i.z_id=d.z_id','LEFT')->where ...

  4. gdb设置条件断点

    b +行号 if i==9:设置条件断点 finish:执行到当前函数返回处(退出函数) bt:打印栈帧关系

  5. go语言合并两个数组

    https://stackoverflow.com/questions/16248241/concatenate-two-slices-in-go Add dots after the second ...

  6. 大数据之路week03--day05(线程 I)

    真的,身体这个东西一定要爱护好,难受的时候电脑都不想去碰,尤其是胃和肾... 这两天耽误了太多时间,今天好转了立刻学习,即刻不能耽误!. 话不多说,说正事: 1.多线程(理解) (1)多线程:一个应用 ...

  7. Http协议与TCP协议

    背景 在日常工作中,经常会遇到某某框架是基于Http协议或者TCP协议,今天,就针对于该协议,整理下 从本质上来说,Http协议与TCP协议是应用在不同网络层,Http协议处于应用层,TCP处于传输层 ...

  8. 洛谷1546 最短网络Agri-Net【最小生成树】【prim】

    [内含最小生成树Prim模板] 题目:https://www.luogu.org/problemnew/show/P1546 题意:给定一个邻接矩阵.求最小生成树. 思路:点少边多用Prim. Pri ...

  9. CF70E Information Reform

    题意:给你一棵树,要选择若干节点,若一个点i没有选择,则有\(d(dis(i,j))\)的代价,其中j被选择.选择一个点代价为k,求最小代价. 首先,考虑这样一个问题: 如果距离a的最近被选点为i,距 ...

  10. (RERERERERERERERERERERE) BZOJ 2746: [HEOI2012]旅行问题

    二次联通门 : BZOJ 2746: [HEOI2012]旅行问题 神TM STL的vector push_back进一个数后取出时就变成了一个很小的负数.. 调不出来了, 不调了 #include ...