转自:https://morvanzhou.github.io/tutorials/machine-learning/keras/2-2-classifier/#测试模型

下载数据

# download the mnist to the path '~/.keras/datasets/' if it is the first time to be called
# X shape (60,000 28x28), y shape (10,000, )
(X_train, y_train), (X_test, y_test) = mnist.load_data()

data预处理:

X_train = X_train.reshape(X_train.shape[0], -1) / 255.   # normalize
X_test = X_test.reshape(X_test.shape[0], -1) / 255.      # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

导入包:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./", one_hot=True)
X_train=mnist.train.images
Y_train=mnist.train.labels
X_test=mnist.test.images
Y_test=mnist.test.labels

因为(X_train, y_train), (X_test, y_test) = mnist.load_data()需从网上下载数据,由于网络限制,下载失败。

可以先在官网yann.lecun.com/exdb/mnist/上下载四个数据(train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz

在当前目录,不要解压!

#input_data.py该模块在tensorflow.examples.tutorials.mnist下,直接加载来读取上面四个压缩包。

#四个压缩包形式为特殊形式。非图片和标签,要解析。

from tensorflow.examples.tutorials.mnist import input_data

#加载数据路径为"./",为当前路径,自动加载数据,用one-hot方式处理好数据。

#read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。 第一个参数为路径,写"./"表示当前路径,其会判断该路径下有没有数据,没有的话会自动下载数据。

mnist = input_data.read_data_sets("./", one_hot=True)  

相关的包:

model.Sequential():用来一层一层的去建立神经层。

layers.Dense,表示这个神经层是全连接层。

layers.Activation,激励函数

optimizers.RMSprop,优化器采用RMSprop,加速神经网络训练方法。

Keras工作流程:

  1. 定义训练数据:输入张量和目标张量
  2. 定义层组成的网络(或模型),将输入映射到目标
  3. 配置学习过程:选择损失函数、优化器和需要监控的指标
  4. 调用模型的fit方法在训练数据上进行迭代

代码:

import numpy as np
np.random.seed(1337) # for reproducibility
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import RMSprop
#读取数据,其中,X_train为55000*784,Y_train为55000*10,X_test为10000*784,Y_test大小为10000*10.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./", one_hot=True)
X_train=mnist.train.images
Y_train=mnist.train.labels
X_test=mnist.test.images
Y_test=mnist.test.labels #建立神经网络模型,一共两层,第一层输入784个变量,输出为32,激活函数为relu,第二层输入是上层的输出32,输出为10,激活函数为softmax。
model = Sequential([
Dense(32, input_dim=784),
Activation('relu'),
Dense(10),
Activation('softmax'),
])
#采用RMSprop来求解模型,设学习率lr为0.001,以及别的参数。
rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
#激活模型,优化器为rmsprop,损失函数为交叉熵,metric,里面可以放入需要计算的,比如cost、accuracy、score等
model.compile(optimizer=rmsprop,
loss='categorical_crossentropy',
metrics=['accuracy'])
#训练网络,用fit函数,导入数据,训练次数为20,每批处理32个
model.fit(X_train, Y_train, nb_epoch=20, batch_size=32)
#测试模型
print('\nTesting ------------')
# Evaluate the model with the metrics we defined earlier
loss, accuracy = model.evaluate(X_test, Y_test) print('test loss: ', loss)
print('test accuracy: ', accuracy)

结果:

 

Keras手写识别例子(1)----softmax的更多相关文章

  1. (五) Keras Adam优化器以及CNN应用于手写识别

    视频学习来源 https://www.bilibili.com/video/av40787141?from=search&seid=17003307842787199553 笔记 Adam,常 ...

  2. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  3. 李宏毅 Keras手写数字集识别(优化篇)

    在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. TensorFlow MNIST(手写识别 softmax)实例运行

    TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...

  6. TensorFlow 入门之手写识别(MNIST) softmax算法 二

    TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  7. 微软手写识别模块sdk及delphi接口例子

    http://download.csdn.net/download/coolstar1204/2008061 微软手写识别模块sdk及delphi接口例子

  8. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  9. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

随机推荐

  1. android程序在调试时出现了套接字异常“java.net.SocketException: Permission denied”该如何解决

    Socket不能对外连接,错误不会被报出,调试的时候,能看到Exception, 一般是抛出 java.net.socketexception permission denied这个异常.只要你的程序 ...

  2. centos7 安装vsftpd的步骤

    感觉非常坑,依照网上说的没一个都測试了,可一直都报错. 不断的又一次安装不下10次,最后一次最终測试出了正确的方法. #官网配置说明## https://security.appspot.com/vs ...

  3. jvm的运行模式 client和 server两种

    jvm的运行模式 client和 server两种 学习了:https://www.cnblogs.com/fsjohnhuang/p/4270505.html 在jdk 9的情况下,好像没有clie ...

  4. [Javascript] AbortController to cancel the fetch request

    We are able to cancel the fetch request by using AbortController with RxJS Observable. return Observ ...

  5. 605B. Lazy Student(codeforces Round 335)

    B. Lazy Student time limit per test 2 seconds memory limit per test 256 megabytes input standard inp ...

  6. 用 JSQMessagesViewController 创建一个 iOS 聊天 App - 第 2 部分

    原文链接 : Create an iOS Chat App using JSQMessagesViewController – Part 2 原文作者 : Mariusz Wisniewski 译者 ...

  7. dp状态压缩

    dp状态压缩 动态规划本来就很抽象,状态的设定和状态的转移都不好把握,而状态压缩的动态规划解决的就是那种状态很多,不容易用一般的方法表示的动态规划问题,这个就更加的难于把握了.难点在于以下几个方面:状 ...

  8. 认证与授权协议对比:OAuth2、OpenID、SMAL

    认证授权是目前大多数系统都必须要实现都功能,认证就是验证用户都身份,授权就是验证身份后对受限资源的访问控制.最开始是单个平台要做,后来在互联网时代到来,一个账户可登陆多个平台,然后是各种开放平台账户共 ...

  9. 第2章 安装Nodejs 2-2 Nodejs版本常识

  10. Rancher 2:添加 NFS client provisioner 动态提供 Kubernetes 后端存储卷

    一.前提说明 1.说明: NFS client provisioner 利用 NFS Server 给 Kubernetes 作为持久存储的后端,并且动态提供PV. 默认 rancher 2 的存储类 ...