matlab练习程序(神经网络分类)
注:这里的练习鉴于当时理解不完全,可能会有些错误,关于神经网络的实践可以参考我的这篇博文
这里的代码只是简单的练习,不涉及代码优化,也不涉及神经网络优化,所以我用了最能体现原理的方式来写的代码。
激活函数用的是h = 1/(1+exp(-y)),其中y=sum([X Y].*w)。
代价函数用的是E = 1/2*(t-h)^2,其中t为目标值,t为1代表是该类,t为0代表不是该类。
权值更新采用BP算法。
网络1形式如下,没有隐含层,1个偏置量,输入直接连接输出:
分类结果:
代码如下:
clear all;
close all;
clc; n=;
randn('seed',);
mu1=[ ];
S1=[0.5 ;
0.5];
P1=mvnrnd(mu1,S1,n); mu2=[ ];
S2=[0.5 ;
0.5];
P2=mvnrnd(mu2,S2,n); mu3=[ ];
S3=[0.5 ;
0.5];
P3=mvnrnd(mu3,S3,n); P=[P1;P2;P3];
meanP=mean(P); P=[P(:,)-meanP() P(:,)-meanP()]; sigma = ; X=P(:,);
Y=P(:,);
B=rand(*n,); w1 = rand(*n,);
w2 = rand(*n,);
w3 = rand(*n,); w4 = rand(*n,);
w5 = rand(*n,);
w6 = rand(*n,); for i=:*n
i
while y1 = X(i)*w1(i) + Y(i)*w4(i) + B(i);
y2 = X(i)*w2(i) + Y(i)*w5(i) + B(i);
y3 = X(i)*w3(i) + Y(i)*w6(i) + B(i); h1 = /(+exp(-y1));
h2 = /(+exp(-y2));
h3 = /(+exp(-y3)); e1 = /*( - h1)^;
e2 = /*( - h2)^;
e3 = /*( - h3)^; if i<=n && e1<=0.0000001
break;
elseif i>n && i<=*n && e2<0.0000001
break;
elseif i>*n && e3<0.0000001
break;
end if i<=n
w1(i) = w1(i)-sigma*(h1-)*h1*(-h1)*X(i);
w2(i) = w2(i)-sigma*(h2-)*h2*(-h2)*X(i);
w3(i) = w3(i)-sigma*(h3-)*h3*(-h3)*X(i); w4(i) = w4(i)-sigma*(h1-)*h1*(-h1)*Y(i);
w5(i) = w5(i)-sigma*(h2-)*h2*(-h2)*Y(i);
w6(i) = w6(i)-sigma*(h3-)*h3*(-h3)*Y(i); B(i) =B(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3));
elseif i>n && i<=*n
w1(i) = w1(i)-sigma*(h1-)*h1*(-h1)*X(i);
w2(i) = w2(i)-sigma*(h2-)*h2*(-h2)*X(i);
w3(i) = w3(i)-sigma*(h3-)*h3*(-h3)*X(i); w4(i) = w4(i)-sigma*(h1-)*h1*(-h1)*Y(i);
w5(i) = w5(i)-sigma*(h2-)*h2*(-h2)*Y(i);
w6(i) = w6(i)-sigma*(h3-)*h3*(-h3)*Y(i); B(i) =B(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3));
else
w1(i) = w1(i)-sigma*(h1-)*h1*(-h1)*X(i);
w2(i) = w2(i)-sigma*(h2-)*h2*(-h2)*X(i);
w3(i) = w3(i)-sigma*(h3-)*h3*(-h3)*X(i); w4(i) = w4(i)-sigma*(h1-)*h1*(-h1)*Y(i);
w5(i) = w5(i)-sigma*(h2-)*h2*(-h2)*Y(i);
w6(i) = w6(i)-sigma*(h3-)*h3*(-h3)*Y(i); B(i) =B(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3));
end end
end plot(P(:,),P(:,),'o');
hold on; flag = ;
M=[];
for x=-:0.3:
for y=-:0.3: H=[];
for i=:*n
y1 = x*w1(i)+y*w4(i) +B(i);
y2 = x*w2(i)+y*w5(i) +B(i);
y3 = x*w3(i)+y*w6(i) +B(i);
h1=/(+exp(-y1));
h2=/(+exp(-y2));
h3=/(+exp(-y3)); H=[H;h1 h2 h3];
end
% H1 = mean(H(:n,));
% H2 = mean(H(n:*n,));
% H3 = mean(H(*n:*n,)); meanH = mean(H);
H1 = meanH();
H2 = meanH();
H3= meanH();
if H1>H2 && H1>H3
plot(x,y,'g.')
elseif H2 > H1 && H2 > H3
plot(x,y,'r.')
elseif H3 > H1 && H3 > H2
plot(x,y,'b.')
end end
end
网络2形式如下,有1个隐含层,2个偏置量:
分类结果:
代码如下:
clear all;
close all;
clc; n=;
randn('seed',);
mu1=[ ];
S1=[0.5 ;
0.5];
P1=mvnrnd(mu1,S1,n); mu2=[ ];
S2=[0.5 ;
0.5];
P2=mvnrnd(mu2,S2,n); mu3=[ ];
S3=[0.5 ;
0.5];
P3=mvnrnd(mu3,S3,n); P=[P1;P2;P3];
meanP=mean(P); P=[P(:,)-meanP() P(:,)-meanP()]; sigma = ; X=P(:,);
Y=P(:,); B1=rand(*n,);
B2=rand(*n,); w1 = rand(*n,);
w2 = rand(*n,); w3 = rand(*n,);
w4 = rand(*n,);
w5 = rand(*n,); for i=:*n
i
while y0 = X(i)*w1(i) + Y(i)*w2(i) + B1(i);
h0 = /(+exp(-y0)); y1 = h0*w3(i) + B2(i);
y2 = h0*w4(i) + B2(i);
y3 = h0*w5(i) + B2(i); h1 = /(+exp(-y1));
h2 = /(+exp(-y2));
h3 = /(+exp(-y3)); e1 = /*( - h1)^;
e2 = /*( - h2)^;
e3 = /*( - h3)^; if i<=n && e1<=0.0000001
break;
elseif i>n && i<=*n && e2<0.0000001
break;
elseif i>*n && e3<0.0000001
break;
end %e1
if i<=n w1(i) = w1(i)- sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*X(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*X(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*X(i));
w2(i) = w2(i)- sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*Y(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*Y(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*Y(i));
B1(i) = B1(i)- sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)); w3(i) = w3(i)-sigma*(h1-)*h1*(-h1)*h0;
w4(i) = w4(i)-sigma*(h2-)*h2*(-h2)*h0;
w5(i) = w5(i)-sigma*(h3-)*h3*(-h3)*h0;
B2(i) =B2(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3)); elseif i>n && i<=*n
w1(i) = w1(i)-sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*X(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*X(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*X(i));
w2(i) = w2(i)-sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*Y(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*Y(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*Y(i));
B1(i) =B1(i)- sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)); w3(i) = w3(i)-sigma*(h1-)*h1*(-h1)*h0;
w4(i) = w4(i)-sigma*(h2-)*h2*(-h2)*h0;
w5(i) = w5(i)-sigma*(h3-)*h3*(-h3)*h0;
B2(i) =B2(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3)); else
w1(i) = w1(i)-sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*X(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*X(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*X(i));
w2(i) = w2(i)-sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0)*Y(i) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0)*Y(i) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)*Y(i));
B1(i) =B1(i)- sigma*((h1-)*h1*(-h1)*w3(i)*h0*(-h0) + (h2-)*h2*(-h2)*w4(i)*h0*(-h0) + (h3-)*h3*(-h3)*w5(i)*h0*(-h0)); w3(i) = w3(i)-sigma*(h1-)*h1*(-h1)*h0;
w4(i) = w4(i)-sigma*(h2-)*h2*(-h2)*h0;
w5(i) = w5(i)-sigma*(h3-)*h3*(-h3)*h0;
B2(i) =B2(i)- sigma*((h1-)*h1*(-h1)+(h2-)*h2*(-h2)+(h3-)*h3*(-h3)); end end
end plot(P(:,),P(:,),'o');
hold on; flag = ;
M=[];
for x=-:0.3:
for y=-:0.3: H=[];
for i=:*n
y0 = x*w1(i)+y*w2(i) +B1(i);
h0=/(+exp(-y0)); y1 = h0*w3(i) + B2(i);
y2 = h0*w4(i) + B2(i);
y3 = h0*w5(i) + B2(i); h1 =/(+exp(-y1));
h2 =/(+exp(-y2));
h3 =/(+exp(-y3)); H=[H;h1 h2 h3];
end meanH = mean(H);
H1 = meanH();
H2 = meanH();
H3= meanH();
if H1>H2 && H1>H3
plot(x,y,'g.')
elseif H2 > H1 && H2 > H3
plot(x,y,'r.')
elseif H3 > H1 && H3 > H2
plot(x,y,'b.')
end end
end
网络3形式如下,有2个隐含层,2个偏置量:
分类结果:
代码如下:
clear all;
close all;
clc; n=;
randn('seed',);
mu1=[ ];
S1=[0.5 ;
0.5];
P1=mvnrnd(mu1,S1,n); mu2=[ ];
S2=[0.5 ;
0.5];
P2=mvnrnd(mu2,S2,n); mu3=[ ];
S3=[0.5 ;
0.5];
P3=mvnrnd(mu3,S3,n); P=[P1;P2;P3];
meanP=mean(P); P=[P(:,)-meanP() P(:,)-meanP()]; sigma = ; X=P(:,);
Y=P(:,); B1=rand(*n,);
B2=rand(*n,); w1 = rand(*n,);
w2 = rand(*n,); w3 = rand(*n,);
w4 = rand(*n,); w5 = rand(*n,);
w6 = rand(*n,);
w7 = rand(*n,); w8 = rand(*n,);
w9 = rand(*n,);
w10 = rand(*n,); for i=:*n
i
while y1 = X(i)*w1(i) + Y(i)*w3(i) + B1(i);
y2 = X(i)*w2(i) + Y(i)*w4(i) + B1(i); h1 = /(+exp(-y1));
h2 = /(+exp(-y2)); dh1 = h1*(-h1);
dh2 = h2*(-h2); y3 = h1*w5(i) + h2*w8(i)+ B2(i);
y4 = h1*w6(i) + h2*w9(i)+ B2(i);
y5 = h1*w7(i) + h2*w10(i)+ B2(i); h3 = /(+exp(-y3));
h4 = /(+exp(-y4));
h5 = /(+exp(-y5)); dh3 = h3*(-h3);
dh4 = h4*(-h4);
dh5 = h5*(-h5); e1 = /*( - h3)^;
e2 = /*( - h4)^;
e3 = /*( - h5)^; if i<=n && e1<=0.0000001
break;
elseif i>n && i<=*n && e2<0.0000001
break;
elseif i>*n && e3<0.0000001
break;
end %e1
if i<=n w1(i) = w1(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*X(i);
w2(i) = w2(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*X(i); w3(i) = w3(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*Y(i);
w4(i) = w4(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*Y(i); B1(i) = B1(i)- sigma*(((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i))*dh1+((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i))*dh2); w5(i) = w5(i)-sigma*(h3-)*dh3*h1;
w6(i) = w6(i)-sigma*(h4-)*dh4*h1;
w7(i) = w7(i)-sigma*(h5-)*dh5*h1; w8(i) = w8(i)-sigma*(h3-)*dh3*h2;
w9(i) = w9(i)-sigma*(h4-)*dh4*h2;
w10(i) = w10(i)-sigma*(h5-)*dh5*h2; B2(i) =B2(i)- sigma*((h3-)*dh3+(h4-)*dh4+(h5-)*dh5); elseif i>n && i<=*n
w1(i) = w1(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*X(i);
w2(i) = w2(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*X(i); w3(i) = w3(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*Y(i);
w4(i) = w4(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*Y(i); B1(i) = B1(i)- sigma*(((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i))*dh1+((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i))*dh2); w5(i) = w5(i)-sigma*(h3-)*dh3*h1;
w6(i) = w6(i)-sigma*(h4-)*dh4*h1;
w7(i) = w7(i)-sigma*(h5-)*dh5*h1; w8(i) = w8(i)-sigma*(h3-)*dh3*h2;
w9(i) = w9(i)-sigma*(h4-)*dh4*h2;
w10(i) = w10(i)-sigma*(h5-)*dh5*h2; B2(i) =B2(i)- sigma*((h3-)*dh3+(h4-)*dh4+(h5-)*dh5); else
w1(i) = w1(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*X(i);
w2(i) = w2(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*X(i); w3(i) = w3(i) -sigma * ((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i)) * dh1*Y(i);
w4(i) = w4(i) -sigma * ((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i)) * dh2*Y(i); B1(i) = B1(i)- sigma*(((h3-)*dh3*w5(i)+(h4-)*dh4*w6(i)+(h5-)*dh5*w7(i))*dh1+((h3-)*dh3*w8(i)+(h4-)*dh4*w9(i)+(h5-)*dh5*w10(i))*dh2); w5(i) = w5(i)-sigma*(h3-)*dh3*h1;
w6(i) = w6(i)-sigma*(h4-)*dh4*h1;
w7(i) = w7(i)-sigma*(h5-)*dh5*h1; w8(i) = w8(i)-sigma*(h3-)*dh3*h2;
w9(i) = w9(i)-sigma*(h4-)*dh4*h2;
w10(i) = w10(i)-sigma*(h5-)*dh5*h2; B2(i) =B2(i)- sigma*((h3-)*dh3+(h4-)*dh4+(h5-)*dh5); end end
end plot(P(:,),P(:,),'o');
hold on; flag = ;
M=[];
for x=-:0.3:
for y=-:0.3:
% x=-;
% y=;
H=[];
for i=:*n
y1 = x*w1(i) + y*w3(i) + B1(i);
y2 = x*w2(i) + y*w4(i) + B1(i); h1 = /(+exp(-y1));
h2 = /(+exp(-y2)); dh1 = h1*(-h1);
dh2 = h2*(-h2); y3 = h1*w5(i) + h2*w8(i)+ B2(i);
y4 = h1*w6(i) + h2*w9(i)+ B2(i);
y5 = h1*w7(i) + h2*w10(i)+ B2(i); h3 = /(+exp(-y3));
h4 = /(+exp(-y4));
h5 = /(+exp(-y5)); H=[H;h3 h4 h5];
end
% H1 = mean(H(:n,));
% H2 = mean(H(n+:*n,));
% H3 = mean(H(*n+:*n,)); meanH = mean(H);
H1 = meanH();
H2 = meanH();
H3= meanH(); M=[M;H1 H2 H3 x y];
if H1>H2 && H1>H3
plot(x,y,'g.')
elseif H2 > H1 && H2 > H3
plot(x,y,'r.')
elseif H3 > H1 && H3 > H2
plot(x,y,'b.')
end end
end
后面我计划对网络分别使用softmax,权重初始化,正则化,ReLu激活函数,交叉熵代价函数与卷积的形式进行优化。
matlab练习程序(神经网络分类)的更多相关文章
- matlab练习程序(神经网络识别mnist手写数据集)
记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...
- 详细MATLAB 中BP神经网络算法的实现
MATLAB 中BP神经网络算法的实现 BP神经网络算法提供了一种普遍并且实用的方法从样例中学习值为实数.离散值或者向量的函数,这里就简单介绍一下如何用MATLAB编程实现该算法. 具体步骤 这里 ...
- matlab练习程序(SUSAN检测)
matlab练习程序(SUSAN检测) SUSAN算子既可以检测角点也可以检测边缘,不过角点似乎比不过harris,边缘似乎比不过Canny.不过思想还是有点意思的. 主要思想就是:首先做一个和原图像 ...
- sklearn神经网络分类
sklearn神经网络分类 神经网络学习能力强大,在数据量足够,隐藏层足够多的情况下,理论上可以拟合出任何方程. 理论部分 sklearn提供的神经网络算法有三个: neural_network.Be ...
- (转)matlab练习程序(HOG方向梯度直方图)
matlab练习程序(HOG方向梯度直方图)http://www.cnblogs.com/tiandsp/archive/2013/05/24/3097503.html HOG(Histogram o ...
- matlab练习程序(异或分类)
clear all; close all; clc; %生成两组已标记数据 randn(); mu1=[ ]; S1=[; 0.5]; P1=mvnrnd(mu1,S1,); mu2=[ ]; S2= ...
- Matlab的BP神经网络工具箱及其在函数逼近中的应用
1.神经网络工具箱概述 Matlab神经网络工具箱几乎包含了现有神经网络的最新成果,神经网络工具箱模型包括感知器.线性网络.BP网络.径向基函数网络.竞争型神经网络.自组织网络和学习向量量化网络.反馈 ...
- matlab任务:FCM分类
一个朋友让帮忙做图像分类,用FCM聚类算法,网上查了一下,FCM基本都是对一幅图像进行像素的分类,跟他说的任务不太一样,所要做的是将一个文件夹里的一千多幅图像进行分类.图像大概是这个样子的(是25*2 ...
- [翻译]LSP程序的分类
翻译的太垃圾,不建议其它人阅读本文. Note:LSP现在已经不推荐使用.自windows8和windows Server2012开始,使用Windows Filtering Platform. Wi ...
随机推荐
- 使用FormData格式在前后端传递数据
为什么一定要使用formdata格式……很大原因是因为当时我犯蠢…… 前端肯定是JS了,具体不写了,使用Postman测试,后端语言是Java,框架Spring Boot,使用IntelliJ IDE ...
- centos7下安装samba服务器
samba笔记: http://services.linuxpanda.tech/%E7%BD%91%E7%BB%9C%E6%96%87%E4%BB%B6%E5%85%B1%E4%BA%AB/samb ...
- Thrift架构介绍
Thrift是一个跨语言的服务部署框架,最初由Facebook于2007年开发,2008年进入Apache开源项目.Thrift通过一个中间语言(IDL, 接口定义语言)来定义RPC的接口和数据类型, ...
- [总结]数论和组合计数类数学相关(定理&证明&板子)
0 写在前面 0.0 前言 由于我太菜了,导致一些东西一学就忘,特开此文来记录下最让我头痛的数学相关问题. 一些引用的文字都注释了原文链接,若侵犯了您的权益,敬请告知:若文章中出现错误,也烦请告知. ...
- MySQL中间件之ProxySQL(8):SQL语句的重写规则
返回ProxySQL系列文章:http://www.cnblogs.com/f-ck-need-u/p/7586194.html 1.为什么要重写SQL语句 ProxySQL在收到前端发送来的SQL语 ...
- Hyperledger Fabric之模型
本文主要介绍Hyperledger Fabric的主要设计特点,为了满足功能丰富.可定制.企业化区块链解决方案. Assets - 资产定义,使得任何形式的资产,从食物到汽车到货币都可以进行自由的交换 ...
- 第一册:lesson3-4.
原文: A:My coat and my umbrella please?Here is my ticket. B:Thank you sir.Number five.Here is your umb ...
- Intellij idea 项目目录设置 与包的显示创建
1.把目录设置成为层级结构显示.和eclipse类似 去掉flatten Packages前面的勾 在项目中创建多级包的时候要注意,必须在Java下建,并且要全输入才能识别
- 介绍一款文档神器:pandoc
http://pandoc.org/ 因为工作需要,将一批markdown的文档转换成word文档,找来找去,这个pandoc真是神器 啊,推荐给大家 If you need to convert f ...
- vs2010打不开vs2017的.sln文件,出现错误提示 “选择的文件是解决方案文件 但是用此应用程序的较新版本创建的,无法打开”
解决方案: 1.复制下面这段语句 Microsoft Visual Studio Solution File, Format Version 11.00 # Visual Studio 2010 2. ...