转自:https://morvanzhou.github.io/tutorials/machine-learning/keras/2-2-classifier/#测试模型

下载数据

# download the mnist to the path '~/.keras/datasets/' if it is the first time to be called
# X shape (60,000 28x28), y shape (10,000, )
(X_train, y_train), (X_test, y_test) = mnist.load_data()

data预处理:

X_train = X_train.reshape(X_train.shape[0], -1) / 255.   # normalize
X_test = X_test.reshape(X_test.shape[0], -1) / 255.      # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

导入包:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./", one_hot=True)
X_train=mnist.train.images
Y_train=mnist.train.labels
X_test=mnist.test.images
Y_test=mnist.test.labels

因为(X_train, y_train), (X_test, y_test) = mnist.load_data()需从网上下载数据,由于网络限制,下载失败。

可以先在官网yann.lecun.com/exdb/mnist/上下载四个数据(train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz

在当前目录,不要解压!

#input_data.py该模块在tensorflow.examples.tutorials.mnist下,直接加载来读取上面四个压缩包。

#四个压缩包形式为特殊形式。非图片和标签,要解析。

from tensorflow.examples.tutorials.mnist import input_data

#加载数据路径为"./",为当前路径,自动加载数据,用one-hot方式处理好数据。

#read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。 第一个参数为路径,写"./"表示当前路径,其会判断该路径下有没有数据,没有的话会自动下载数据。

mnist = input_data.read_data_sets("./", one_hot=True)  

相关的包:

model.Sequential():用来一层一层的去建立神经层。

layers.Dense,表示这个神经层是全连接层。

layers.Activation,激励函数

optimizers.RMSprop,优化器采用RMSprop,加速神经网络训练方法。

Keras工作流程:

  1. 定义训练数据:输入张量和目标张量
  2. 定义层组成的网络(或模型),将输入映射到目标
  3. 配置学习过程:选择损失函数、优化器和需要监控的指标
  4. 调用模型的fit方法在训练数据上进行迭代

代码:

  1. import numpy as np
  2. np.random.seed(1337) # for reproducibility
  3. from keras.datasets import mnist
    from keras.models import Sequential
  4. from keras.layers import Dense, Activation
  5. from keras.optimizers import RMSprop
  6. #读取数据,其中,X_train为55000*784,Y_train为55000*10,X_test为10000*784,Y_test大小为10000*10.
  7. from tensorflow.examples.tutorials.mnist import input_data
  8. mnist = input_data.read_data_sets("./", one_hot=True)
  9. X_train=mnist.train.images
  10. Y_train=mnist.train.labels
  11. X_test=mnist.test.images
  12. Y_test=mnist.test.labels
  13.  
  14. #建立神经网络模型,一共两层,第一层输入784个变量,输出为32,激活函数为relu,第二层输入是上层的输出32,输出为10,激活函数为softmax。
  15. model = Sequential([
  16. Dense(32, input_dim=784),
  17. Activation('relu'),
  18. Dense(10),
  19. Activation('softmax'),
  20. ])
  21. #采用RMSprop来求解模型,设学习率lr为0.001,以及别的参数。
  22. rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
  23. #激活模型,优化器为rmsprop,损失函数为交叉熵,metric,里面可以放入需要计算的,比如cost、accuracy、score等
  24. model.compile(optimizer=rmsprop,
  25. loss='categorical_crossentropy',
  26. metrics=['accuracy'])
  27. #训练网络,用fit函数,导入数据,训练次数为20,每批处理32个
  28. model.fit(X_train, Y_train, nb_epoch=20, batch_size=32)
  29. #测试模型
  30. print('\nTesting ------------')
  31. # Evaluate the model with the metrics we defined earlier
  32. loss, accuracy = model.evaluate(X_test, Y_test)
  33.  
  34. print('test loss: ', loss)
  35. print('test accuracy: ', accuracy)

结果:

 

Keras手写识别例子(1)----softmax的更多相关文章

  1. (五) Keras Adam优化器以及CNN应用于手写识别

    视频学习来源 https://www.bilibili.com/video/av40787141?from=search&seid=17003307842787199553 笔记 Adam,常 ...

  2. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  3. 李宏毅 Keras手写数字集识别(优化篇)

    在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. TensorFlow MNIST(手写识别 softmax)实例运行

    TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...

  6. TensorFlow 入门之手写识别(MNIST) softmax算法 二

    TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  7. 微软手写识别模块sdk及delphi接口例子

    http://download.csdn.net/download/coolstar1204/2008061 微软手写识别模块sdk及delphi接口例子

  8. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  9. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

随机推荐

  1. AngularJS:一行JS代码实现控件验证效果

    如上图所示,我们需要实现如下这些验证功能: 控件都是必输控件 都需要控制最大长度 第一次打开页面,控件不能显示为错误状态 输入内容再清空后,必输控件需要显示为错误状态 只有所有输入合法后,发布按钮才能 ...

  2. HDU 5239 上海大都会 D题(线段树+数论)

    打表,发现规律是存在一定次数(较小)后,会出现a=(a*a)%p.可以明显地发现本题与线段树有关.设置标记flag,记录本段内的数是否均已a=a*a%p.若是,则不需更新,否则更新有叶子结点,再pus ...

  3. java 执行可执行文件时提示“could not find or load main class ”的问题

  4. clCreateCommandQueue': was declared deprecated

    今天在配置opencl的开发环境.測试用例时,用的是intel的sdk开发包.遇到了这个问题: clCreateCommandQueue': was declared deprecated 也就是说这 ...

  5. ExtJs4.1布局具体解释

    Border布局: Ext.onReady(function(){     Ext.QuickTips.init();     Ext.create('Ext.container.Viewport', ...

  6. TiDB(1): server測试安装

    本文的原文连接是: http://blog.csdn.net/freewebsys/article/details/50600352 未经博主同意不得转载. 博主地址是:http://blog.csd ...

  7. ASP.NET MVC 认证模块报错:“System.Configuration.Provider.ProviderException: 未启用角色管理器功能“

    新建MVC4项目的时候 选 Internet 应用程序的话,出来的示例项目就自带了默认的登录认证等功能.如果选空或者基本,就没有. 如果没有,现在又想加进去,怎么办呢? 抄啊.将示例项目的代码原原本本 ...

  8. jQuery总结03

    1 控制网页元素属性和样式的 jQuery 方法有哪些? 2 利用 jQuery 插入网页节点的方法有哪些? 3 jQuery 中绑定事件是什么,如何解除绑定? 4 jQuery 中的动画效果包括哪些 ...

  9. 关于QT版本的安装配置的一些困惑

    大概是之前安装和使用QT太顺利了,什么都没注意就开始使用了.在使用VS2012开发Qt5.31的程序一段时间以后,虽然好用,但是发现其编译的程序不能在XP上使用,要打补丁才行.不仅VS2012本身要打 ...

  10. fixed和absolute

    fixed是相对于浏览器窗口固定 absolute是相对于整体网页固定.(整体网页包括所有的内容,包含右侧滑动条滑动所能看到的内容)