本项目参考:

https://www.bilibili.com/video/av31500120?t=4657

训练代码

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Train import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
from keras.optimizers import Adam
from keras.models import load_model nb_class = 10
nb_epoch = 4
batchsize = 128 '''
1st,准备参数
X_train: (0,255) --> (0,1) CNN中似乎没有必要?cnn自动转了吗?
设置时间函数测试一下两者对比。
小技巧:X_train /= 255.0 就可不用转换成浮点了???
'''
# Preparing your data mnist. MAC /.keras/datasets linux home ./keras/datasets
(X_train, Y_train), (X_test, Y_test) = mnist.load_data() # setup data shape
# (-1, 28, 28, 1) -1表示有默认个数据集,28*28是像素,1是1个通道
X_train = X_train.reshape(-1, 28, 28, 1) # tensorflow-channel last,while theano-channel first
X_test = X_test.reshape(-1, 28, 28, 1) X_train = X_train/255.000
X_test = X_test/255.000 # One-hot 6 --> [0,0,0,0,0,1,0,0,0]
Y_train = np_utils.to_categorical(Y_train, nb_class)
Y_test = np_utils.to_categorical(Y_test, nb_class) '''
2nd,设置模型
''' # setup model
model = Sequential() # 1st convolution layer # 滤波器要在28x28的图上横着走32次
model.add(Convolution2D(
filters=32, # 此处把filters写成了filter,找了半天。囧
kernel_size=[5, 5], # 滤波器是5x5大小的,可以是list列表,也可以是tuple元祖
padding='same', # padding也是一个窗口模式
input_shape=(28, 28, 1) # 定义输入的数据,必须是元组
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 2nd Conv2D layer
model.add(Convolution2D(
filters=64,
kernel_size=(5, 5),
padding='same',
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 讨论,卷积层数和最终结果关系。 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
model.add(Flatten()) # 把卷积层里面的全部转换层一维数组
model.add(Dense(1024)) # Dense is output
model.add(Activation('relu')) # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
# 把卷积层里面的全部转换层一维数组
model.add(Dense(256)) # Dense is output
model.add(Activation('tanh')) # 2nd Fully connected Dense
model.add(Dense(10))
model.add(Activation('softmax')) '''
3rd 定义参数
'''
# Define Optimizer and setup Param
adam = Adam(lr=0.0001) # Adam实例化 # compile model
model.compile(
optimizer=adam, # optimizer='Adam'也是可以的,且默认lr=0.001,此处已经实例化为adam
loss='categorical_crossentropy',
metrics=['accuracy'],
) # Run network
model.fit(x=X_train, # 更多参数可以查看fit函数,alt+鼠标左键单击fit
y=Y_train,
epochs=nb_epoch,
batch_size=batchsize, # p=parameter, batch_size; v=var, batch size
verbose=1, # 显示模式
validation_data=(X_test, Y_test)
)
model.save('model_name.h5')
# evaluation = model.evaluate(X_test, Y_test) 现在用model.fit(validation_data)
# print(evaluation) 效果一样

测试代码:

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Predict import numpy as np
from keras.models import load_model # ??
import matplotlib.pyplot as plt
import matplotlib.image as processimage # load trained model
model = load_model('model_name.h5') # 已经训练好了的模型,在根目录下,默认为model_name.h5 # 写一个来预测的类
class MainPredictImg(object): def __init__(self):
pass def pred(self, filename):
pred_img = processimage.imread(filename)
pred_img = np.array(pred_img)
pred_img = pred_img.reshape(-1, 28, 28, 1)
prediction = model.predict(pred_img)
final_prediction = [result.argmax() for result in prediction][0]
a = 0
for i in prediction[0]:
print(a)
print('Percent:{:.30%}'.format(i))
a = a+1
return final_prediction def main():
predict = MainPredictImg()
res = predict.pred('4.png')
print("your number is:-->", res) if __name__ == '__main__':
main()

keras02 - hello convolution neural network 搭建第一个卷积神经网络的更多相关文章

  1. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  2. Convolution Neural Network (CNN) 原理与实现

    本文结合Deep learning的一个应用,Convolution Neural Network 进行一些基本应用,参考Lecun的Document 0.1进行部分拓展,与结果展示(in pytho ...

  3. Deeplearning - Overview of Convolution Neural Network

    Finally pass all the Deeplearning.ai courses in March! I highly recommend it! If you already know th ...

  4. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  5. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  6. 深度学习:卷积神经网络(convolution neural network)

    (一)卷积神经网络 卷积神经网络最早是由Lecun在1998年提出的. 卷积神经网络通畅使用的三个基本概念为: 1.局部视觉域: 2.权值共享: 3.池化操作. 在卷积神经网络中,局部接受域表明输入图 ...

  7. Recurrent Neural Network系列1--RNN(循环神经网络)概述

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  8. 【面向代码】学习 Deep Learning(三)Convolution Neural Network(CNN)

    ========================================================================================== 最近一直在看Dee ...

  9. TensorFlow从入门到理解(三):你的第一个卷积神经网络(CNN)

    运行代码: from __future__ import print_function import tensorflow as tf from tensorflow.examples.tutoria ...

随机推荐

  1. 每日分享!~ JavaScript(拖拽事件)

    浏览器的拖拉事件 拖拉(drag)指的是,用户在某个对象上按下鼠标键不放,拖动它到另一个位置,然后释放鼠标键,将该对象放在那里. 拖拉的对象有好几种,包括元素节点.图片.链接.选中的文字等等.在网页中 ...

  2. sqlserver数据库备份时出现3241问题

    工作中需要将生产上的数据库备份到测试数据库一份,然后同步生产环境进行测试.但是在将数据库还原的过程中,遇到了下面的问题: 说是,介质簇结构不正确,猜测应该是sqlserver的版本不一致的问题,然后查 ...

  3. 学习ASP.NET Core Razor 编程系列十二——在页面中增加校验

    学习ASP.NET Core Razor 编程系列目录 学习ASP.NET Core Razor 编程系列一 学习ASP.NET Core Razor 编程系列二——添加一个实体 学习ASP.NET ...

  4. 生产线平衡问题的+Leapms线性规划方法

    知识点 第一类生产线平衡问题,第二类生产线平衡问题 整数线性规划模型,+Leapms模型,直接求解,CPLEX求解 装配生产线平衡问题 (The Assembly Line Balancing Pro ...

  5. WinDbg调试C#技巧,解决CPU过高、死锁、内存爆满

    软件安装 安装问题:执行 .loadby sos clr 命令无效 解决办法: .load C:\Windows\Microsoft.NET\Framework64\v4.0.30319\SOS.dl ...

  6. 版本控制工具——Git常用操作(上)

    本文由云+社区发表 作者:工程师小熊 摘要:用了很久的Git和svn,由于总是眼高手低,没能静下心来写这些程序员日常开发最常用的知识点.现在准备开一个专题,专门来总结一下版本控制工具,让我们从git开 ...

  7. [JavaScript] 函数节流(throttle)和函数防抖(debounce)

    js 的函数节流(throttle)和函数防抖(debounce)概述 函数防抖(debounce) 一个事件频繁触发,但是我们不想让他触发的这么频繁,于是我们就设置一个定时器让这个事件在 xxx 秒 ...

  8. Spring Cloud Alibaba基础教程:Sentinel使用Apollo存储规则

    上一篇我们介绍了如何通过Nacos的配置功能来存储限流规则.Apollo是国内用户非常多的配置中心,所以,今天我们继续说说Spring Cloud Alibaba Sentinel中如何将流控规则存储 ...

  9. Dynamics 365-为什么查到的Record的Id是Guid初始值

    通过代码查询CRM数据,这个是开发经常会碰到的情况,获取返回的EntityCollection之后,我们会拿Entity.Id做进一步操作.笔者最近碰到的情况,是Entity.Id是个初始值.先上一段 ...

  10. Linux通过NFS实现文件共享

    在项目生产环境我们经常需要实现文件共享,传统的常见方案是通过NFS,实现服务器之间共享某一块磁盘,通过网络传输将分散的文件集中存储在一块指定的共享磁盘,实现基本的文件共享.实现这种方案,分服务端和客户 ...