莫烦大大keras学习Mnist识别(4)-----RNN
一、步骤:
导入包以及读取数据
设置参数
数据预处理
构建模型
编译模型
训练以及测试模型
二、代码:
1、导入包以及读取数据
#导入包
import numpy as np
np.random.seed(1337) #设置之后每次执行代码,产生的随机数都一样 from tensorflow.examples.tutorials.mnist import input_data
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN , Activation , Dense
from keras.optimizers import Adam #读取数据
mnist = input_data.read_data_sets('E:\jupyter\TensorFlow\MNIST_data',one_hot = True)
X_train = mnist.train.images
Y_train = mnist.train.labels
X_test = mnist.test.images
Y_test = mnist.test.labels
2、设置参数
#设置参数
time_steps = 28 # same as the height of the image
input_size = 28 # same as the width of the image
batch_size = 50
batch_index = 0
output_size = 10
cell_size = 50
lr = 0.001
3、数据预处理
#数据预处理
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
4.1、构建RNN模型
#构建模型
model = Sequential() #RNN层
model.add(SimpleRNN(
batch_input_shape =(None,time_steps,input_size), # 输入维度
output_dim = cell_size, #输出维度
)) #输出层
model.add(Dense(output_size))
model.add(Activation('softmax'))
4.2、构建LSTM模型
def builtLSTMModel(time_steps,input_size,cell_size,output_size):
model = Sequential() #添加LSTM
model.add(LSTM(
batch_input_size = (None,time_steps,input_size),
output_dim = cell_size,
return_sequences = True, #要不要每个时间点的输出都输出
stateful = True, #batch和batch有联系,batch和batch之间的状态需要连接起来
)) #添加输出层
model.add(TimeDistributed(Dense(output_size))) #每一个时间点的输出都要加入全连接层。
return model
5、训练模型以及测试
#训练模型
for step in range(4001): X_batch = X_train[batch_index:batch_size + batch_index,:,:]
Y_batch = Y_train[batch_index:batch_size + batch_index,:]
cost = model.train_on_batch(X_batch,Y_batch) batch_index += batch_size
batch_index = 0 if batch_index >= X_train.shape[0] else batch_index if step % 500 ==0:
loss , acc = model.evaluate(X_test,Y_test,batch_size =Y_test.shape[0])
print(loss,',',acc)
莫烦大大keras学习Mnist识别(4)-----RNN的更多相关文章
- 莫烦大大keras学习Mnist识别(3)-----CNN
一.步骤: 导入模块以及读取数据 数据预处理 构建模型 编译模型 训练模型 测试 二.代码: 导入模块以及读取数据 #导包 import numpy as np np.random.seed(1337 ...
- 莫烦大大keras的Mnist手写识别(5)----自编码
一.步骤: 导入包和读取数据 数据预处理 编码层和解码层的建立 + 构建模型 编译模型 训练模型 测试模型[只用编码层来画图] 二.代码: 1.导入包和读取数据 #导入相关的包 import nump ...
- 莫烦大大TensorFlow学习笔记(9)----可视化
一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...
- 莫烦python教程学习笔记——总结篇
一.机器学习算法分类: 监督学习:提供数据和数据分类标签.--分类.回归 非监督学习:只提供数据,不提供标签. 半监督学习 强化学习:尝试各种手段,自己去适应环境和规则.总结经验利用反馈,不断提高算法 ...
- 莫烦大大TensorFlow学习笔记(8)----优化器
一.TensorFlow中的优化器 tf.train.GradientDescentOptimizer:梯度下降算法 tf.train.AdadeltaOptimizer tf.train.Adagr ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 莫烦python教程学习笔记——validation_curve用于调参
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——learn_curve曲线用于过拟合问题
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
随机推荐
- [转]十五天精通WCF——第二天 告别烦恼的config配置
经常搞wcf的基友们肯定会知道,当你的应用程序有很多的“服务引用”的时候,是不是有一种疯狂的感觉...从一个环境迁移到另外一个环境,你需要改变的 endpoint会超级tmd的多,简直就是搞死了人.. ...
- [Cypress] Create Aliases for DOM Elements in Cypress Tests
We’ll often need to access the same DOM elements multiple times in one test. Your first instinct mig ...
- hdu1068 Girls and Boys --- 最大独立集
有一个集合男和一个集合女,给出两集合间一些一一相应关系.问该两集合中的最大独立集的点数. 最大独立集=顶点总数-最大匹配数 此题中.若(a,b)有关.则(b,a)有关.每个关系算了两次,相当于二分图的 ...
- iOS 通讯录编程【总结】
第一大块儿:读取通讯录 1.iOS 6以上系统,争取获取用户允许: 初始化的时候须要推断.设备是否授权 -(id)init{ self = [super init]; [self createdABH ...
- SpringMVC案例2----基于spring2.5的注解实现
和上一篇一样,首先看一下项目结构和jar包 watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvYmVuamFtaW5fd2h4/font/5a6L5L2T/fo ...
- webService总结(一)——使用CXF公布和调用webService(不使用Spring)
CXF和Axis2是两个比較流行的webService框架,接下来我会写几篇博客简介怎样使用这两种框架. 首先,先简介一下CXF的使用. CXF公布webService有多种方法.这里我介绍三种: 1 ...
- React-Router 中文简明教程(上)
概述 说起 前端路由,如果你用过前端 MV* 框架构建 SPA 应用(单页面应用),对此一定不陌生. 传统开发中的 路由,是由服务端根据不同的用户请求地址 URL,返回不同内容的页面,而前端路由则将这 ...
- C++对象内存布局 (一)
一.前言 在看了皓哥的C++对象的布局之后:http://blog.csdn.net/haoel/article/details/3081328.感觉自己还是总是不能记得很清楚,故在此总结一下C++对 ...
- (转)dp动态规划分类详解
dp动态规划分类详解 转自:http://blog.csdn.NET/cc_again/article/details/25866971 动态规划一直是ACM竞赛中的重点,同时又是难点,因为该算法时间 ...
- 时序数据库深入浅出之存储篇——本质LSMtree,同时 metric(比如温度)+tags 分片
什么是时序数据库 先来介绍什么是时序数据.时序数据是基于时间的一系列的数据.在有时间的坐标中将这些数据点连成线,往过去看可以做成多纬度报表,揭示其趋势性.规律性.异常性:往未来看可以做大数据分析,机器 ...