Keras人工神经网络多分类(SGD)
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.preprocessing import LabelEncoder
from keras.optimizers import SGD
from keras.layers import LSTM # load dataset
dataframe = pd.read_csv("./data/iris1.csv", header=None)
dataset = dataframe.values
X = dataset[:, 0:19].astype(float)
dummy_y1 = dataset[:, 19]
m,n=1682,6
dum_imax=np.zeros((m,n))
# print(type(dum_imax))
for i in range(m):
# print(i)
# exit()
if dummy_y1[i]!=0:
dum_imax[i][dummy_y1[i]-1]=1
else:
dum_imax[i][5]=1
# print(dum_imax)
dummy_y =dum_imax
print(dummy_y)
print(type(dummy_y[0][0])) def baseline_model():
model = Sequential()
model.add(Dense(output_dim=50, input_dim=19, activation='relu'))
# model.add(LSTM(128))
model.add(Dropout(0.4))
model.add(Dense(output_dim=6, input_dim=50, activation='softmax'))
# Compile model
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
# model.compile(loss='categorical_crossentropy', optimizer=sgd)
#编译模型。由于我们做的是二元分类,所以我们指定损失函数为binary_crossentropy,以及模式为binary
#另外常见的损失函数还有mean_squared_error、categorical_crossentropy等,请阅读帮助文件。
#求解方法我们指定用adam,还有sgd、rmsprop等可选
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
return model
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=40, batch_size=256)
print(estimator) # splitting data into training set and test set. If random_state is set to an integer, the split datasets are fixed.
X_train, X_test, Y_train, Y_test = train_test_split(X, dummy_y, test_size=0.2, random_state=0)#train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,
print(len(X_train[0]))
print(len(Y_train[0]))
estimator.fit(X_train, Y_train,nb_epoch = 100)#训练模型,学习一百次 # make predictions
print(X_test)
pred = estimator.predict(X_test)
print(pred)
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # inverse numeric variables to initial categorical labels
# init_lables = encoder.inverse_transform(pred)
# print(init_lables) # k-fold cross-validate
# seed = 42
# np.random.seed(seed)
'''
n_splits : 默认3,最小为2;K折验证的K值
shuffle : 默认False;shuffle会对数据产生随机搅动(洗牌)
random_state :默认None,随机种子
'''
kfold = KFold(n_splits=5, shuffle=True)#定义5折,在对数据进行划分之前,对数据进行随机混洗 results = cross_val_score(estimator, X, dummy_y, cv=kfold)#在数据集上,使用k fold交叉验证,对估计器estimator进行评估。
print("baseline:%.2f%%(%.2f%%)"%(results.mean()*100,results.std()*100))#返回的结果,是10次数据集划分后,每次的评估结果。评估结果包括平均准确率和标准差
Keras人工神经网络多分类(SGD)的更多相关文章
- keras人工神经网络构建入门
//2019.07.29-301.Keras 是提供一些高度可用神经网络框架的 Python API ,能帮助你快速的构建和训练自己的深度学习模型,它的后端是 TensorFlow 或者 Theano ...
- [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)
3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...
- neurosolutions 人工神经网络集成开发环境 keras
人工神经网络集成开发环境 : http://www.neurosolutions.com/ keras: https://github.com/fchollet/keras 文档 http ...
- [DL学习笔记]从人工神经网络到卷积神经网络_2_卷积神经网络
先一层一层的说卷积神经网络是啥: 1:卷积层,特征提取 我们输入这样一幅图片(28*28): 如果用传统神经网络,下一层的每个神经元将连接到输入图片的每一个像素上去,但是在卷积神经网络中,我们只把输入 ...
- [DL学习笔记]从人工神经网络到卷积神经网络_1_神经网络和BP算法
前言:这只是我的一个学习笔记,里边肯定有不少错误,还希望有大神能帮帮找找,由于是从小白的视角来看问题的,所以对于初学者或多或少会有点帮助吧. 1:人工全连接神经网络和BP算法 <1>:人工 ...
- 人工神经网络 Artificial Neural Network
2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...
- SIGAI深度学习第二集 人工神经网络1
讲授神经网络的思想起源.神经元原理.神经网络的结构和本质.正向传播算法.链式求导及反向传播算法.神经网络怎么用于实际问题等 课程大纲: 神经网络的思想起源 神经元的原理 神经网络结构 正向传播算法 怎 ...
- 机器学习笔记之人工神经网络(ANN)
人工神经网络(ANN)提供了一种普遍而且实际的方法从样例中学习值为实数.离散值或向量函数.人工神经网络由一系列简单的单元相互连接构成,其中每个单元有一定数量的实值输入,并产生单一的实值输出. 上面是一 ...
- 人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五)
原文:人工神经网络简介和单层网络实现AND运算--AForge.NET框架的使用(五) 前面4篇文章说的是模糊系统,它不同于传统的值逻辑,理论基础是模糊数学,所以有些朋友看着有点迷糊,如果有兴趣建议参 ...
随机推荐
- SparseArray
使用SparseArray更加节省内存空间的使用,SparseArray也是以key和value对数据进行保存的.使用的时候只需要指定value的类型即可.并且key不需要封装成对象类型. Has ...
- 31-java中知识总结:list, set, map, stack, queue
import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Lin ...
- 15-算法训练 P1103
http://lx.lanqiao.cn/problem.page?gpid=T372 算法训练 P1103 时间限制:1.0s 内存限制:256.0MB 编程实现两个复数的运算 ...
- ApplicationContextAware的使用
一.这个接口有什么用? 当一个类实现了这个接口(ApplicationContextAware)之后,这个类就可以方便获得ApplicationContext中的所有bean.换句话说,就是这个类可以 ...
- first H5
<!DOCTYPE html> <html> <head> <meta http-equiv="Content-Type" content ...
- [Z]haproxy+keepalived高可用群集
http://blog.51cto.com/13555423/2067131 Haproxy是目前比较流行的一种集群调度工具Haproxy 与LVS.Nginx的比较LVS性能最好,但是搭建相对复杂N ...
- 定时器中的this和函数封装的简单理解;
一.定时器中的this: 不管定时器中的函数怎么写,它里面的this都是window: 在函数前面讲this赋值给一个变量,函数内使用这个变量就可以改变this的指向 二.函数封装 函数封装是一种函数 ...
- 5I - 汉诺塔IV
还记得汉诺塔III吗?他的规则是这样的:不允许直接从最左(右)边移到最右(左)边(每次移动一定是移到中间杆或从中间移出),也不允许大盘放到小盘的上面.xhd在想如果我们允许最大的盘子放到最上面会怎么样 ...
- windows上安装RabbitMQ
windows下 安装 rabbitMQ 及操作常用命令 rabbitMQ是一个在AMQP协议标准基础上完整的,可服用的企业消息系统.它遵循Mozilla Public License开源协议,采用 ...
- Java的反射技术
什么是反射机制 Java的反射机制是在运行状态中,对于任意一个类,都能够知道这个类的所有属性和方法:对于任意一个对象,都能调用它的任意属性和方法.这种动态获取信息以及动态调用对象属性和方法的即使称为J ...