莫烦Python 4

新建模板小书匠

RNN Classifier 循环神经网络

问题描述

使用RNN对MNIST里面的图片进行分类

关键

SimpleRNN()参数

  • batch_input_shape

    使用状态RNN的注意事项

可以将RNN设置为‘stateful’,意味着由每个batch计算出的状态都会被重用于初始化下一个batch的初始状态。状态RNN假设连续的两个batch之中,相同下标的元素有一一映射关系。

要启用状态RNN,请在实例化层对象时指定参数stateful=True,并在Sequential模型使用固定大小的batch:通过在模型的第一层传入batch_size=(…)和input_shape来实现。在函数式模型中,对所有的输入都要指定相同的batch_size。

如果要将循环层的状态重置,请调用.reset_states(),对模型调用将重置模型中所有状态RNN的状态。对单个层调用则只重置该层的状态。

(samples,timesteps,input_dim)

代码

'''
RNN Classifier 循环神经网络
'''
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN, Activation, Dense
from keras.optimizers import Adam time_step = 28
input_size = 28
batch_size = 50
output_size = 10
cell_size = 50
LR = 0.001 (X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = X_train.reshape(-1, 28, 28) / 255. # normalize
X_test = X_test.reshape(-1, 28, 28) / 255. # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10) model = Sequential()
model.add(
SimpleRNN(
batch_input_shape=(None, time_step, input_size),
units=cell_size
)
) model.add(
Dense(output_size)
) model.add(Activation('softmax')) adam = Adam(LR)
model.compile(
optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']
)
model.summary() model.fit(X_train, y_train, batch_size=batch_size, epochs=2, verbose=2, validation_data=(X_test, y_test))

结果

Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_2 (SimpleRNN) (None, 50) 3950
_________________________________________________________________
dense_2 (Dense) (None, 10) 510
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
=================================================================
Total params: 4,460
Trainable params: 4,460
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
- 12s - loss: 0.6643 - accuracy: 0.7966 - val_loss: 0.4501 - val_accuracy: 0.8550
Epoch 2/2
- 9s - loss: 0.3220 - accuracy: 0.9087 - val_loss: 0.2445 - val_accuracy: 0.9359

莫烦Python 4的更多相关文章

  1. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  2. 莫烦python教程学习笔记——validation_curve用于调参

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  3. 莫烦python教程学习笔记——learn_curve曲线用于过拟合问题

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  4. 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  5. 莫烦python教程学习笔记——数据预处理之normalization

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  6. 莫烦python教程学习笔记——线性回归模型的属性

    #调用查看线性回归的几个属性 # Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg # ...

  7. 莫烦python教程学习笔记——使用波士顿数据集、生成用于回归的数据集

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  8. 莫烦python教程学习笔记——使用鸢尾花数据集

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  9. 莫烦python课程里面的bug修复;课程爬虫小练习爬百度百科

    我今天弄了一下午修改这个代码,最后还是弄好了.原因是正则表达式的筛选不够准确,有时候是会带http:baidu这些东西的.所以需要一个正则表达式的断言,然后还有一点是如果his里面只有一个元素就不要再 ...

  10. Tensorflow 教程系列 | 莫烦Python

    Tensorflow 简介 1.1 科普: 人工神经网络 VS 生物神经网络 1.2 什么是神经网络 (Neural Network) 1.3 神经网络 梯度下降 1.4 科普: 神经网络的黑盒不黑 ...

随机推荐

  1. React Refs-知识点整理记录

    一.Refs的作用 通过Refs,可以访问到 1. DOM节点. 2. render方法中创建的React元素.(class组件的实例) 二.访问节点或者实例有什么用?为什么要使用Refs来访问? 访 ...

  2. Pytest插件之pytest-base-url切换测试环境

    Pytest插件之pytest-base-url切换测试环境 安装  pip install pytest-base-url 应用场景 利用参数--base-url或者配置(pytest.ini中ba ...

  3. 一次代码重构 JavaScript 图连通性判定

    简介 说重构其实就是整理了代码,第一次自己手写写的很丑,然后看了书上写的,虽然和书上的思路不同但是整理后几乎一样漂亮 效果 整体代码如下 class Node { AdjNodes = new Set ...

  4. 不像JVM的JVM

    1.面向对象 面向对象的思想:将功能封装到对象中,通过对象去实现 面向对象的目的:将复杂的事情简单化,将以前过程中的执行者变成了指挥者且符合现在人们的思考习惯 面向对象的三大特征: 封装:将对象的实现 ...

  5. SQLSERVER 临时表和表变量到底有什么区别?

    一:背景 1. 讲故事 今天和大家聊一套面试中经常被问到的高频题,对,就是 临时表 和 表变量 这俩玩意,如果有朋友在面试中回答的不好,可以尝试看下这篇能不能帮你成功迈过. 二:到底有什么区别 1. ...

  6. 2021级《JAVA语言程序设计》上机考试试题6

    首先管理员页面 代码: <%@ page language="java" contentType="text/html; charset=UTF-8" p ...

  7. 安装redhat6.10 出现的问题

    安装redhat6.10 操作系统不定时重启情况说明   曾出现报错如下: 在UEFI模式下安装RHEL6.10,安装完毕后系统第一次重启无法进入操作系统,同时屏幕上出现错误提示: Invalid m ...

  8. JavaScript的闭包和作用域

    作用域相关 作用域的概念: 作用域是在运行时代码中的某些特定部分中变量,函数和对象的可访问性.换句话说,作用域决定了代码区块中变量和其他资源的可见性: 作用域的类型: 全局作用域: 最外层函数和在最外 ...

  9. Iceberg 数据治理及查询加速实践

    数据治理 Flink 实时写入 Iceberg 带来的问题 在实时数据源源不断经过 Flink 写入的 Iceberg 的过程中,Flink 通过定时的 Checkpoint 提交 snapshot ...

  10. 403. 青蛙过河 (Hard)

    问题描述 403. 青蛙过河 (Hard) 一只青蛙想要过河. 假定河流被等分为若干个单元格,并且在每一个单元格内都有可能放有一块石子(也有可能没有). 青蛙可以跳上石子,但是不可以跳入水中. 给你石 ...