#coding=utf-8
#1.数据预处理
import numpy as np #导入模块,numpy是扩展链接库
import pandas as pd
import tensorflow
import keras
from keras.utils import np_utils
np.random.seed(10) #设置seed可以产生的随机数据
from keras.datasets import mnist #导入模块,下载读取mnist数据
(x_train_image,y_train_label),\
(x_test_image,y_test_label)=mnist.load_data() #下载读取mnist数据
print('train data=',len(x_train_image))
print('test data=',len(x_test_image))
print('x_train_image:',x_train_image.shape)
print('y_train_label:',y_train_label.shape)
import matplotlib.pyplot as plt
def plot_image(image):
fig=plt.gcf()
fig.set_size_inches(2,2)
plt.imshow(image,cmap='binary')
plt.show()
y_train_label[0]
import matplotlib.pyplot as plt
def plot_image_labels_prediction(image,lables,prediction,idx,num=10):
fig=plt.gcf()
fig.set_size_inches(12,14)
if num>25:num=25
for i in range(0,num):
ax=plt.subplot(5,5,i+1)
ax.imshow(image[idx],cmap='binary')
title="lable="+str(lables[idx])
if len(prediction)>0:
title+=",predict="+str(prediction[idx])
ax.set_title(title,fontsize=10)
ax.set_xticks([]);ax.set_yticks([])
idx+=1
plt.show()
plot_image_labels_prediction(x_train_image,y_train_label,[],0,10)
plot_image_labels_prediction(x_test_image,y_test_label,[],0,10)
x_Train=x_train_image.reshape(60000,784).astype('float32') #以reshape转化成784个float
x_Test=x_test_image.reshape(10000,784).astype('float32')
x_Train_normalize=x_Train/255 #将features标准化
x_Test_normalize=x_Test/255
y_Train_OneHot=np_utils.to_categorical(y_train_label)#将训练数据和测试数据的label进行one-hot encoding转化
y_Test_OneHot=np_utils.to_categorical(y_test_label)
#2.建立模型
from keras.models import Sequential #可以通过Sequential模型传递一个layer的list来构造该模型,序惯模型是多个网络层的线性堆叠
from keras.layers import Dense #全连接层
from keras.layers import Dropout #避免过度拟合
model=Sequential()
#建立输入层、隐藏层
model.add(Dense(units=1000,
input_dim=784,
kernel_initializer='normal',
activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=1000,
kernel_initializer='normal',
activation='relu'))
model.add(Dropout(0.5))
#建立输出层
model.add(Dense(units=10,
kernel_initializer='normal',
activation='softmax'))
print(model.summary()) #查看模型的摘要
#3、进行训练
#对训练模型进行设置,损失函数、优化器、权值
model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])
# 设置训练与验证数据比例,80%训练,20%测试,执行10个训练周期,每一个周期200个数据,显示训练过程2次
train_history=model.fit(x=x_Train_normalize,
y=y_Train_OneHot,validation_split=0.2,
epochs=10,batch_size=200,verbose=2)
#显示训练过程
import matplotlib.pyplot as plt
def show_train_history(train_history,train,validation):
plt.plot(train_history.history[train])
plt.plot(train_history.history[validation])
plt.title('Train History')
plt.ylabel(train)
plt.xlabel('Epoch')
plt.legend(['train','validation'],loc='upper left') #显示左上角标签
plt.show()
show_train_history(train_history,'acc','val_acc') #画出准确率评估结果
show_train_history(train_history,'loss','val_loss') #画出误差执行结果
#以测试数据评估模型准确率
scores=model.evaluate(x_Test_normalize,y_Test_OneHot) #创建变量存储评估后的准确率数据,(特征值,真实值)
print()
print('accuracy',scores[1])
#进行预测
prediction=model.predict_classes(x_Test)
prediction
plot_image_labels_prediction(x_test_image,y_test_label,prediction,idx=340)
#4、建立模型提高预测准确率
#建立混淆矩阵
import pandas as pd #pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的
pd.crosstab(y_test_label,prediction,
rownames=['label'],colnames=['predict'])
#建立真实值与预测值dataFrame
df=pd.DataFrame({'label':y_test_label,'predict':prediction})
df[:2]
df[(df.label==5)&(df.predict==3)]
plot_image_labels_prediction(x_test_image,y_test_label,prediction,idx=340,num=1)

keras—多层感知器识别手写数字算法程序的更多相关文章

  1. 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容: 1.基于多层感知器的mnist手写数字识别(代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  2. TensorFlow—多层感知器—MNIST手写数字识别

    1 import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import ...

  3. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

  4. 4.2tensorflow多层感知器MLP识别手写数字最易懂实例代码

    自己开发了一个股票智能分析软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  多层感知器MLP(m ...

  5. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  6. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  7. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  8. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  9. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

随机推荐

  1. HDU 1269 迷宫城堡(向量)(Tarjan模版题)

    迷宫城堡 Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others)Total Submis ...

  2. [原]System.IO.Path.Combine 路径合并

    使用 ILSpy 工具查看了 System.IO.Path 类中的 Combine 方法 对它的功能有点不放心,原方法实现如下: // System.IO.Path /// <summary&g ...

  3. 合并单元格/VBA

    ' 合并某一列中相同数据的单元格 Sub MergeColumns() Dim rowN As Integer Dim i, j, m, n As Integer Dim col As Integer ...

  4. 1115 Counting Nodes in a BST (30 分)

    1115 Counting Nodes in a BST (30 分) A Binary Search Tree (BST) is recursively defined as a binary tr ...

  5. php实现socket

    一.Socket 简介 1.socket只不过是一个数据结构. 2.使用这个socket数据结构去开始一个客户端和服务器之间的会话. 3.服务器是一直在监听准备产生一个新的会话.当一个客户端连接服务器 ...

  6. Scrapy爬取人人网

    Scrapy发送Post请求 防止爬虫被反主要有以下几个策略 动态设置User-Agent(随机切换User-Agent,模拟不同用户的浏览器信息) 禁用Cookies(也就是不启用cookies m ...

  7. RAC 11.2的新特性

    网格即插即用(GPnP) 网格即插即用帮助管理员来维护集群,以前增加或删除节点需要的一些手动操作的步骤现在可以由GPnP来自动实现. GPnP不是一个单独的概念,它依赖于以下特性:在一个XML配置文件 ...

  8. 19.XPath选择器

    1.extract():提取数据 2./text()     :获取节点内容文本 3./@href   :获取节点href属性 4. @         :获取属性名称 需要注意问题: 用定义的规则那 ...

  9. 数据分析利器之hive优化十大原则

    hive之于数据民工,就如同锄头之于农民伯伯.hive用的好,才能从地里(数据库)里挖出更多的数据来. 用过hive的朋友,我想或多或少都有类似的经历:一天下来,没跑几次hive,就到下班时间了. h ...

  10. sitemap

    sitemap对于网站就像是字典的索引目录,而这个目录的读者则是搜索引擎的爬虫.网站有了sitemap,有助于搜索引擎“了解”网站,这样会有助于站点的内容被收录. sitemap是一个由google主 ...