神经网络MPLClassifier分类
代码:
- # -*- coding: utf-8 -*-
- """
- Created on Fri Aug 24 14:38:56 2018
- @author: zhen
- """
- import gzip
- import pickle
- import numpy as np
- from sklearn.neural_network import MLPClassifier
- # 加载数据
- # 设置编码,解决异常:UnicodeDecodeError: 'ascii' codec can't decode byte 0x90 in position 614: ordinal not in range(128)
- with gzip.open("E:/mnist.pkl.gz") as fp:
- training_data, valid_data, test_data = pickle.load(fp, encoding='bytes')
- x_training_data, y_training_data = training_data
- x_valid_data, y_valid_data = valid_data
- x_test_data, y_test_data = test_data
- classes = np.unique(y_test_data)
- # 将验证集和训练集合并
- x_training_data_final = np.vstack((x_training_data, x_valid_data))
- y_training_data_final = np.append(y_training_data, y_valid_data)
- # 设置神经网络模型参数
- # 使用solver='lbfgs',拟牛顿法,需要较多的跌点次数
- lbfgs = MLPClassifier(solver='lbfgs', activation='relu', alpha=1e-4, hidden_layer_sizes=(50, 50), random_state=1, max_iter=10, verbose=10, learning_rate_init=0.1)
- # 使用solver='adam',基于随机梯度下降的优化算法,准确率较低
- adam = MLPClassifier(solver='adam', activation='relu', alpha=1e-4, hidden_layer_sizes=(50, 50), random_state=1, max_iter=10, verbose=10, learning_rate_init=0.1)
- # 使用solver='sgd',基于梯度下降的自适应优化算法,分批训练数据,效率高,准确性高,建议使用
- sgd = MLPClassifier(solver='sgd', activation='relu', alpha=1e-4, hidden_layer_sizes=(50, 50), random_state=1, max_iter=10, verbose=10, learning_rate_init=0.1)
- # 使用不同算法训练模型
- lbfgs.fit(x_training_data_final, y_training_data_final)
- adam.fit(x_training_data_final, y_training_data_final)
- sgd.fit(x_training_data_final, y_training_data_final)
- # 预测
- lbfgs_predict = lbfgs.predict(x_test_data)
- adam_predict = adam.predict(x_test_data)
- sgd_predict = sgd.predict(x_test_data)
- print(lbfgs_predict)
- print("*******************************************")
- print(adam_predict)
- print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
- print(sgd_predict)
- print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
- # 评估模型
- print(lbfgs.score(x_test_data, y_test_data))
- print("===========================================")
- print(adam.score(x_test_data, y_test_data))
- print("-------------------------------------------")
- print(sgd.score(x_test_data, y_test_data))
- # 输出正确结果
- print(y_test_data)
结果:
max_iter=10
max_iter=20
注意:
1. 当使用pickle加载mnist数据时,python3.x与python2.x差距较大,python3.x会抛出异常,异常信息为:UnicodeDecodeError: 'ascii' codec can't decode byte 0x90 in position 614: ordinal not in range(128)
此时需要指定编码pickle.load(fp, encoding='bytes')来解决异常!
2. 比较lbfgs(拟牛顿法)、adam(基于随机梯度下降的优化算法)和sgd(基于梯度下降的自适应优化算法)可知,lbfgs波动较大,在相同训练数据的情况下,当迭代次数不同时,模型预测准确率波动较大。adam算法模型训练较快,但模型预测准确率较差,适合应用在预测准确率要求不高,响应时间短的地方。sgd算法在模型训练速度和预测准确率方面都能达到较好的效果,建议使用!
神经网络MPLClassifier分类的更多相关文章
- 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)
1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...
- 深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类
使用tensorflow构造神经网络用来进行mnist数据集的分类 相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数, 神经网络使用的损失值为 ...
- 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练
卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...
- TensorFlow.NET机器学习入门【4】采用神经网络处理分类问题
上一篇文章我们介绍了通过神经网络来处理一个非线性回归的问题,这次我们将采用神经网络来处理一个多元分类的问题. 这次我们解决这样一个问题:输入一个人的身高和体重的数据,程序判断出这个人的身材状况,一共三 ...
- 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上
完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...
- 深度学习原理与框架-神经网络-cifar10分类(代码) 1.np.concatenate(进行数据串接) 2.np.hstack(将数据横着排列) 3.hasattr(判断.py文件的函数是否存在) 4.reshape(维度重构) 5.tanspose(维度位置变化) 6.pickle.load(f文件读入) 7.np.argmax(获得最大值索引) 8.np.maximum(阈值比较)
横1. np.concatenate(list, axis=0) 将数据进行串接,这里主要是可以将列表进行x轴获得y轴的串接 参数说明:list表示需要串接的列表,axis=0,表示从上到下进行串接 ...
- Keras人工神经网络多分类(SGD)
import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...
- 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_下
数据读取部分实现 文中采用了tensorflow的从文件直接读取数据的方式,逻辑流程如下, 实现如下, # Author : Hellcat # Time : 2017/12/9 import os ...
- 在 TensorFlow 中实现文本分类的卷积神经网络
在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...
随机推荐
- 理解ScheduledExecutorService中scheduleAtFixedRate和scheduleWithFixedDelay的区别
scheduleAtFixedRate 每间隔一段时间执行,分为两种情况: 当前任务执行时间小于间隔时间,每次到点即执行: /** * 任务执行时间(8s)小于间隔时间(10s) */ public ...
- 关于Mybatis的一些随笔
Mapper.xml头文件 <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http:/ ...
- 【JDBC 笔记】
JDBC 笔记 作者:晨钟暮鼓c个人微信公众号:程序猿的月光宝盒 对应pdf版:https://download.csdn.net/download/qq_22430159/10754554 没有积分 ...
- 全球第一免费开源ERP Odoo Ubuntu最佳开发环境独家首发分享
起源 近年来随着国内的互联网经济的快速腾飞,诞生了很多开源软件创造的市场价值以及企业价值神话,特别是对于企业ERP领域,一直以来都是高昂的国内外产品充实,国内的中小成长型企业越来越需要一套好看又能打, ...
- github SSH配置
目录 github SSH配置 前言 ssh 配置 github SSH配置 前言 github有两种更新的渠道,一种是https的,一种是ssh的,其中https每次都要输入密码,非常烦.所以,最好 ...
- (四)图数据neo4j用户管理
1.用户管理 neo4j可通过内置函数,进行用户的创建.查看.删除. (1)用户创建; CALL dbms.security.createUser(name,password,requridchang ...
- mysql用户创建授权
创建用户: grant select,update,insert,delete,create,drop,alter,index on *.* to 'jyx_mysql'@'%' identified ...
- 使用 Parallels Destop 最小化安装 centOS 操作系统
1. 环境准备 macOS 操作系统 Parallels Destop 13 CentOS 7.6 Minimal ISO 镜像文件 2. 新建操作系统 选择下载好的 CentosOS 7.6 即 C ...
- QPainterPath 不规则提示框
currentPosition()是最后一次绘制后的“结束点”(或初始点),使用moveTo()移动currentPosition()而不会添加任何元素. QPainterPath 合并: 1.方法 ...
- 【JVM系列】一步步解析java执行内幕
对于任何一门语言,要想达到精通的水平,研究它的执行原理(或者叫底层机制)不失为一种良好的方式.在本篇文章中,将重点研究java源代码的执行原理,即从程 序员编写JAVA源代码,到最终形成产品,在整个过 ...