# 1. 数据预处理

import keras

from keras import backend as K
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D num_classes = 10
img_rows, img_cols = 28, 28 # 通过Keras封装好的API加载MNIST数据。其中trainX就是一个60000 * 28 * 28的数组,
# trainY是每一张图片对应的数字。
(trainX, trainY), (testX, testY) = mnist.load_data() # 根据对图像编码的格式要求来设置输入层的格式。
if K.image_data_format() == 'channels_first':
trainX = trainX.reshape(trainX.shape[0], 1, img_rows, img_cols)
testX = testX.reshape(testX.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
trainX = trainX.reshape(trainX.shape[0], img_rows, img_cols, 1)
testX = testX.reshape(testX.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1) trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.0
testX /= 255.0 # 将标准答案转化为需要的格式(one-hot编码)。
trainY = keras.utils.to_categorical(trainY, num_classes)
testY = keras.utils.to_categorical(testY, num_classes)

# 2. 通过Keras的API定义卷机神经网络。
# 使用Keras API定义模型。
model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dense(num_classes, activation='softmax')) # 定义损失函数、优化函数和评测方法。
model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.SGD(),metrics=['accuracy'])
# 3. 通过Keras的API训练模型并计算在测试数据上的准确率。
model.fit(trainX, trainY,batch_size=128,epochs=10,validation_data=(testX, testY)) # 在测试数据上计算准确率。
score = model.evaluate(testX, testY)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

吴裕雄--天生自然TensorFlow高层封装:Keras-CNN的更多相关文章

  1. 吴裕雄--天生自然TensorFlow高层封装:Keras-TensorFlow API

    # 1. 模型定义. import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_ ...

  2. 吴裕雄--天生自然TensorFlow高层封装:Keras-多输入输出

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  3. 吴裕雄--天生自然TensorFlow高层封装:Keras-返回值

    # 1. 数据预处理. import keras from keras.models import Model from keras.datasets import mnist from keras. ...

  4. 吴裕雄--天生自然TensorFlow高层封装:Estimator-自定义模型

    # 1. 自定义模型并训练. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist i ...

  5. 吴裕雄--天生自然TensorFlow高层封装:Estimator-DNNClassifier

    # 1. 模型定义. import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  6. 吴裕雄--天生自然TensorFlow高层封装:Keras-RNN

    # 1. 数据预处理. from keras.layers import LSTM from keras.datasets import imdb from keras.models import S ...

  7. 吴裕雄--天生自然TensorFlow高层封装:解决ImportError: cannot import name 'tf_utils'

    将原来版本的keras卸载了,再安装2.1.5版本的keras就可以了.

  8. 吴裕雄--天生自然TensorFlow高层封装:解决ValueError: Invalid backend. Missing required entry : placeholder

    找到对应的keras配置文件keras.json 将里面的内容修改为以下就可以了

  9. 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型

    # 1. 通过TensorFlow-Slim定义卷机神经网络 import numpy as np import tensorflow as tf import tensorflow.contrib. ...

随机推荐

  1. 15 ~ express ~ 用户数据分页原理和实现

    一,在后台路由 /router/admin.js 中 1,限制获取的数据条数 : User.find().limit(Number) 2,忽略数据的前(Number)条数据 : skip(Number ...

  2. Spring注解@ResponseBody

    @Responsebody 将内容或对象作为http响应正文返回,并调用适合HttpMessageConverter的Adapter转换对象,写入输出流. 写在方法上面表示:表示该方法的返回结果直接写 ...

  3. sql server ------创建本地数据库 SQL Server 排序规则

    sql server完整复制数据库 sql server导入导出方法 SQL Server 排序规则

  4. Day 2:线程与进程系列问题(二)

    补充: 线程的创建方式二: 1.自定义一个实现Runnable接口的类 2.实现Runnable接口中的run方法把自定义线程的任务写在run方法中 3.创建实现Runnable接口的对象 4.创建T ...

  5. 寒假day22

    今天解决了标签模块的一些错误,同时美化了界面

  6. Spring创建Bean的顺序

    一直对Spring创建bean的顺序很好奇,现在总算有时间写个代码测试一下.不想看过程的小伙伴可以直接看结论 目录结构: 其中:bean4.bean5包下的class没有注解@Component,测试 ...

  7. git登录账号密码错误remote: Incorrect username or password (access token)

    git提交时弹框让输入用户和密码,不小心输入错误了 再次提交 一直就提示  remote: Incorrect username or password 错误了,也不弹框要重新输入 解决方法 win1 ...

  8. 进度3_家庭记账本App_Fragment使用SQLite实现简单存储及查询

    AddFragment.java: package com.example.familybooks; import android.content.ContentValues; import andr ...

  9. BZOJ 4084 [Sdoi2015]双旋转字符串

    题解:hash 至今不会unsigned long long 的输出 把B扔进map 找A[mid+1][lenA]在A[1][mid]中的位置 把A[1][mid]贴两遍(套路) 枚举A[mid+1 ...

  10. Python笔记_第五篇_Python数据分析基础教程_NumPy基础

    1. NumPy的基础使用涵盖如下内容: 数据类型 数组类型 类型转换 创建数组 数组索引 数组切片 改变维度 2. NumPy数组对象: NumPy中的ndarray是一个多维数组对象,该兑现共有两 ...