keras_实现cnn_手写数字识别
# conding:utf-8
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD
from keras import optimizers
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import load_model # 数据准备
x_train =np.zeros((4500,28,28,1))
x_test =np.zeros((500,28,28,1))
y_train=[]
y_test=[] for i in range(0,10):
for j in range(1,501):
if j < 451: #将数据保存到训练数据中
x_train[(j-1)+(i*450),:,:,0]=plt.imread('./data/%d/%d_%d.bmp'%(i,i,j)) #reshape 可以降维也就是矩阵变化
y_train.append(i) #append 是读进来的数据进行存储的意思
else: #保存到预测数据中
x_test[(i*50)+(j-452),:,:,0]=plt.imread('./data/%d/%d_%d.bmp'%(i,i,j))
y_test.append(i)
y_t = np.array(y_test).reshape(-1,1)
print(x_train.shape)
# x_train = np.array(x_train).reshape(450,28,28,1)
y_train = np.array(pd.get_dummies(y_train))
print(y_train.shape)
# x_test = np.array(x_test).reshape(50,28,28,1)
y_test = np.array(pd.get_dummies(y_test)) # 模型建立 model = Sequential()
# 第一层:
model.add(Conv2D(32,(3,3),input_shape=(28,28,1),activation='relu',padding='valid'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.3)) #第二层:
# model.add(Conv2D(64,(5,5),activation='relu',padding='same',data_format='channels_first'))
# model.add(MaxPooling2D(pool_size=(2,2)))
# model.add(Dropout(0.25))
model.add(Conv2D(32,(3,3),activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
# 2、全连接层和输出层:
model.add(Flatten())
model.add(Dense(128,activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10,activation='softmax')) model.summary()
model.compile(loss='categorical_crossentropy',#,'binary_crossentropy'
optimizer=optimizers.Adadelta(lr=0.2, rho=0.95, epsilon=1e-06),#,'Adadelta'
metrics=['accuracy']) # 模型训练
model.fit(x_train,y_train,batch_size=128,epochs=35)
y_y = model.predict(x_test)
score = model.evaluate(x_test, y_test, verbose=0)
# 保存模型
# model.save('test/my_model.h5')
print(score)
# 模型导入
# model = load_model('test/my_model.h5')
# y_y = model.predict(x_test)
# y_s = np.argmax(y_y,axis=1).reshape(-1,1)
# score_pred = len((y_t-y_s)[(y_t-y_s)==0])/len(y_t)
# print('准确率:',score_pred)
# plt.figure(figsize=(12,6))
# plt.scatter(list(range(len(y_s))),y_s,c=y_t)
# xlabel = ['数字0','数字1','数字2','数组3','数字4','数字5','数字6','数字7','数字8','数字9']
# plt.yticks(range(10),xlabel)
# plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体为SimHei显示中文
# plt.rcParams['axes.unicode_minus'] = False # 设置正常显示符号
# plt.show()
keras_实现cnn_手写数字识别的更多相关文章
- C#中调用Matlab人工神经网络算法实现手写数字识别
手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化 投影 矩阵 目标定位 Matlab 手写数字图像识别简介: 手写 ...
- CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 机器学习(二)-kNN手写数字识别
一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...
- 利用神经网络算法的C#手写数字识别
欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
通过: 手写数字识别 ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别 ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...
随机推荐
- iOS-读写plist文件
读写plist文件 问题,我有一个plist文件,表示56个民族的,但是里面保存的字典,我想转换成一个数组 好的,那么就先遍历这个plist,然后将结果保存到一个数组中,这里出现的一个问题就是C语言字 ...
- Android Open Source Projects(汇总与整理)
Android Open Source Projects 目前包括: Android开源项目第一篇——个性化控件(View)篇 包括ListView.ActionBar.Menu.ViewPager ...
- 【题解搬运】PAT_A1016 Phone Bills
从我原来的博客上搬运.原先blog作废. 题目 A long-distance telephone company charges its customers by the following rul ...
- Selenium Grid 环境搭建 碰到的unable to access server
1. Slenenium Grid的环境部署, 前提条件: JDK,JRE都已经安装, selenium的standalone jar包放在磁盘 执行如下命令,报错: 2. 在cmd窗口里切换到jar ...
- struts2 action中获取request session application的方法
共四种方式: 其中前两种得到的是Map<String,Object> 后两种得到的才是真正的request对象 而Map就是把request对象中的属性取出做成了键值对而已. [方法一] ...
- WebStorm强大的调试JavaScript功能(转载)
一.JavaScript的调试 目前火狐和Chrome都具备调试JavaScript的功能,而且还是相当的强大.如果纯粹是用浏览器来进行js调试的话,我比较喜欢用火狐.火狐可以安装各种插件,真的是非常 ...
- C++中范围for语句
如果想对string对象中的每个字符做点什么操作,目前最好的办法是使用C++11新标准提供的一种语句:范围for(range for)语句. 示例代码: #include<iostream> ...
- 【PHP】- include、require、include_once 和 require_once的区别
1.include:会将指定的档案读入并且执行里面的程序. 被导入的档案中的程序代码都会被执行,而且这些程序在执行的时候会拥有和源文件中呼叫到 include() 函数的位置相同的变量范围( ...
- Chromium之文件类型
.grp: Generate your project. 是由Json(JavaScript Object Notation)(or Python?)来解析,根据环境(OS,Compiler..)来生 ...
- requests快速入门
Requests 是唯一的一个非转基因的 Python HTTP 库,人类可以安全享用. 警告:非专业使用其他 HTTP 库会导致危险的副作用,包括:安全缺陷症.冗余代码症.重新发明轮子症.啃文档症. ...