Keras 实例 MNIST
原文链接:http://www.one2know.cn/keras_mnist/
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense # 稠密层
from keras.layers import Dropout # Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,Dropout层用于防止过拟合。
from keras.layers import Flatten # Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
from keras.layers.convolutional import Conv2D # 二维卷积层,即对图像的空域卷积。
from keras.layers.convolutional import MaxPooling2D # 空间池化(也叫亚采样或下采样)降低了每个特征映射的维度,但是保留了最重要的信息
from keras.utils import np_utils
from keras import backend as K
K.set_image_dim_ordering('th') # 设置图像的维度顺序(‘tf’或‘th’)
# 当前的维度顺序如果为'th',则输入图片数据时的顺序为:channels,rows,cols,否则:rows,cols,channels
seed = 7
numpy.random.seed(seed)
#将数据reshape,CNN的输入是4维的张量(可看做多维的向量),第一维是样本规模,第二维是像素通道,第三维和第四维是长度和宽度。并将数值归一化和类别标签向量化。
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][pixels][width][height]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28).astype('float32')
X_train = X_train / 255
X_test = X_test / 255
# 将标签转化成one-hot编码
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]
## 接下来构造CNN
# 第一层是卷积层。该层有32个feature map,或者叫滤波器,作为模型的输入层,接受[pixels][width][height]大小的输入数据。feature_map的大小是5*5,其输出接一个‘relu’激活函数。
# 下一层是pooling层,使用了MaxPooling,大小为2*2。
# 下一层是Dropout层,该层的作用相当于对参数进行正则化来防止模型过拟合。
# Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
# 接下来是全连接层,有128个神经元,激活函数采用‘relu’。
# 最后一层是输出层,有10个神经元,每个神经元对应一个类别,输出值表示样本属于该类别的概率大小。
def baseline_model():
# create model
model = Sequential()
model.add(Conv2D(32, (5, 5), input_shape=(1, 28, 28), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
# Compile model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
# 建立模型
model = baseline_model()
# 训练模型
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)
# 模型概要打印
model.summary()
# 模型评估
scores = model.evaluate(X_test, y_test, verbose=0)
print("Baseline Error: %.2f%%" % (100-scores[1]*100))
输出:
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
- 68s - loss: 0.2247 - acc: 0.9358 - val_loss: 0.0776 - val_acc: 0.9754
Epoch 2/10
- 66s - loss: 0.0709 - acc: 0.9787 - val_loss: 0.0444 - val_acc: 0.9853
Epoch 3/10
- 67s - loss: 0.0511 - acc: 0.9843 - val_loss: 0.0435 - val_acc: 0.9855
Epoch 4/10
- 66s - loss: 0.0392 - acc: 0.9880 - val_loss: 0.0391 - val_acc: 0.9873
Epoch 5/10
- 66s - loss: 0.0325 - acc: 0.9898 - val_loss: 0.0341 - val_acc: 0.9893
Epoch 6/10
- 65s - loss: 0.0266 - acc: 0.9918 - val_loss: 0.0318 - val_acc: 0.9890
Epoch 7/10
- 65s - loss: 0.0221 - acc: 0.9929 - val_loss: 0.0348 - val_acc: 0.9886
Epoch 8/10
- 65s - loss: 0.0191 - acc: 0.9941 - val_loss: 0.0308 - val_acc: 0.9890
Epoch 9/10
- 66s - loss: 0.0153 - acc: 0.9951 - val_loss: 0.0325 - val_acc: 0.9897
Epoch 10/10
- 65s - loss: 0.0143 - acc: 0.9957 - val_loss: 0.0301 - val_acc: 0.9903
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 32, 24, 24) 832
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 12, 12) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 32, 12, 12) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 4608) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 589952
_________________________________________________________________
dense_2 (Dense) (None, 10) 1290
=================================================================
Total params: 592,074
Trainable params: 592,074
Non-trainable params: 0
_________________________________________________________________
Baseline Error: 0.97%
Keras 实例 MNIST的更多相关文章
- keras入门--Mnist手写体识别
介绍如何使用keras搭建一个多层感知机实现手写体识别及搭建一个神经网络最小的必备知识 import keras # 导入keras dir(keras) # 查看keras常用的模块 ['Input ...
- keras实现mnist数据集手写数字识别
一. Tensorflow环境的安装 这里我们只讲CPU版本,使用 Anaconda 进行安装 a.首先我们要安装 Anaconda 链接:https://pan.baidu.com/s/1AxdGi ...
- Keras实现MNIST分类
仅仅为了学习Keras的使用,使用一个四层的全连接网络对MNIST数据集进行分类,网络模型各层结点数为:784: 256: 128 : 10: 使用整体数据集的75%作为训练集,25%作为测试 ...
- 莫烦大大keras的Mnist手写识别(5)----自编码
一.步骤: 导入包和读取数据 数据预处理 编码层和解码层的建立 + 构建模型 编译模型 训练模型 测试模型[只用编码层来画图] 二.代码: 1.导入包和读取数据 #导入相关的包 import nump ...
- 莫烦大大keras学习Mnist识别(4)-----RNN
一.步骤: 导入包以及读取数据 设置参数 数据预处理 构建模型 编译模型 训练以及测试模型 二.代码: 1.导入包以及读取数据 #导入包 import numpy as np np.random.se ...
- 莫烦大大keras学习Mnist识别(3)-----CNN
一.步骤: 导入模块以及读取数据 数据预处理 构建模型 编译模型 训练模型 测试 二.代码: 导入模块以及读取数据 #导包 import numpy as np np.random.seed(1337 ...
- Keras载入mnist数据集出错问题解决方案
找到本地keras目录下的mnist.py文件 通常在这个目录下. ..\Anaconda3\Lib\site-packages\keras\datasets 下载mnist.npz文件到本地 下载链 ...
- keras实现mnist手写数字数据集的训练
网络:两层卷积,两层全连接,一层softmax 代码: import numpy as np from keras.utils import to_categorical from keras imp ...
- [机器学习] keras:MNIST手写数字体识别(DeepLearning 的 HelloWord程序)
深度学习界的Hello Word程序:MNIST手写数字体识别 learn from(仍然是李宏毅老师<机器学习>课程):http://speech.ee.ntu.edu.tw/~tlka ...
随机推荐
- 使用log4j进行日志管理
17.1.Log4j简介 作用: 1. 跟踪代码的运行轨迹. 2. 输出调试信息. 三大组成: 1. Logger类-生成日志. 2. Appender类-定义日志输出的目的地. 3. Layou ...
- 【iOS】this class is not key value coding-compliant for the key ...
一般此问题 都是由 interface build 与代码中 IBOutlet 的连接所引起的. 可能是在代码中对 IBOutlet 的名称进行了修改,导致 interface build 中的连接实 ...
- Python-默背单词
数据库单词: 默认单词 单词说明 innodb 事务,主键,外键,tree,表行锁 myisam 主要以插入读取和插入操作 memory 所有数据保存在内存中 ACID 原子性,一致性,隔离性,持 ...
- spark shuffle写操作三部曲之BypassMergeSortShuffleWriter
前言 再上一篇文章 spark shuffle的写操作之准备工作 中,主要介绍了 spark shuffle的准备工作,本篇文章主要介绍spark shuffle使用BypassMergeSortSh ...
- Java内部类超详细总结(含代码示例)
什么是内部类 什么是内部类? 顾名思义,就是将一个类的定义放在另一个类的内部. 概念很清楚,感觉很简单,其实关键在于这个内部类放置的位置,可以是一个类的作用域范围.一个方法的或是一个代码块的作用域范围 ...
- sqoop 密码别名模式 --password-alias
sqoop要使用别名模式隐藏密码 1.首先使用命令创建别名 hadoop credential create xiaopengfei -provider jceks://hdfs/user/pass ...
- redis缓存介绍以及常见问题浅析
# 没缓存的日子: 对于web来说,是用户量和访问量支持项目技术的更迭和前进.随着服务用户提升.可能会出现一下的一些状况: 页面并发量和访问量并不多,mysql足以支撑自己逻辑业务的发展.那么其实可以 ...
- 0x33 同余
目录 定义 同余类与剩余系 费马小定理 欧拉定理 证明: 欧拉定理的推论 证明: 应用: 定义 若整数 $a$ 和整数 $b$ 除以正整数 $m$ 的余数相等,则称 $a,b$ 模 $m$ 同余,记为 ...
- 深入浅出Apriori关联分析算法(一)
在美国有这样一家奇怪的超市,它将啤酒与尿布这样两个奇怪的东西放在一起进行销售,并且最终让啤酒与尿布这两个看起来没有关联的东西的销量双双增加.这家超市的名字叫做沃尔玛. 你会不会觉得有些不可思议?虽然事 ...
- Javascript实现简单地发布订阅模式
不论是在程序世界里还是现实生活中,发布—订阅模式的应用都非常广泛.我们先看一下现实中的例子. 小明最近看上了一套房子,到了售楼处之后才被告知,该楼盘的房子早已售罄.好在售楼MM告诉小明,不久后还有一些 ...