初学神经网络算法--梯度下降、反向传播、优化(交叉熵代价函数、L2规范化) 柔性最大值(softmax)还未领会其要义,之后再说

有点懒,暂时不想把算法重新总结,先贴一个之前做过的反向传播的总结ppt

其实python更好实现些,不过我想好好学matlab,就用matlab写了

然后是算法源码,第一个啰嗦些,不过可以帮助理解算法

function bpback1(ny,eta,mini_size,epoch)
%ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
%该函数为梯度下降+反向传播
%images
[numimages,images]=bpimages('train-images.idx3-ubyte');
[n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
%labels
[numlabels,labels]=bplabels('train-labels.idx1-ubyte');
[n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
%init w/b
%rand('state',sum(100*clock));
%ny=30;eta=0.01;mini_size=10;
w1=randn(ny,784);
b1=randn(ny,1);
w2=randn(10,ny);
b2=randn(10,1);
for epo=1:epoch
for nums=1:numimages/mini_size
    for num=(nums-1)*mini_size+1:nums*mini_size
        x=images(:,num);
        y=labels(:,num);
    net2=w1*x;               %input of net2
    for i=1:ny
    hidden(i)=1/(1+exp(-net2(i)-b1(i)));%output of net2
    end
    net3=w2*hidden';            %input of net3
    for i=1:10
    o(i)=1/(1+exp(-net3(i)-b2(i)));%output of net3
    end

    %back
    for i=1:10
    delta3(i)=(y(i)-o(i))*o(i)*(1-o(i));%delta of net3
    end
    for i=1:ny
    delta2(i)=delta3*w2(:,i)*hidden(i)*(1-hidden(i));%delta of net2
    end
    %updata w/b
    for i=1:10
        for j=1:ny
    w2(i,j)=w2(i,j)+eta*delta3(i)*hidden(j)/mini_size;
        end
    end
    for i=1:ny
        for j=1:784
    w1(i,j)=w1(i,j)+eta*delta2(i)*x(j)/mini_size;
        end
    end
    for i=1:10
    b2(i)=b2(i)+eta*delta3(i);
    end
    for i=1:ny
    b1(i)=b1(i)+eta*delta2(i);
    end
    end
end
%calculate sum of error
%accuracy
sum0=0;
for i=1:1000
    x0=test_data_x(:,i);
    y0=test_data_y(:,i);
    a1=[];
    a2=[];
    s1=w1*x0;
    for j=1:ny
    a1(j)=1/(1+exp(-s1(j)-b1(j)));
    end
    s2=w2*a1';
    for j=1:10
    a2(j)=1/(1+exp(-s2(j)-b2(j)));
    end
    a2=a2';
    [m1,n1]=max(a2);
    [m2,n2]=max(y0);
    if n1==n2
        sum0=sum0+1;
    end
    %e=o'-y;
    %sigma(num)=e'*e;
    sigma(i)=sumsqr(a2-y0);   %代价为误差平方和
end
sigmas(epo)=sum(sigma)/(2*1000);
fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
end
plot(sigmas);
xlabel('epoch');
ylabel('cost on the training_data');
end

  

function bpback2(ny,eta,mini_size,epoch,numda)
%ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
%bpback的优化,包括L2规范化、交叉熵代价函数的引入---结果证明该优化非常赞!
%images
[numimages,images]=bpimages('train-images.idx3-ubyte');
[n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
%labels
[numlabels,labels]=bplabels('train-labels.idx1-ubyte');
[n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
%init w/b
%ny=30;eta=0.05;mini_size=10;epoch=10;numda=0.1;
rand('state',sum(100*clock));
w1=randn(ny,784)/sqrt(784);
b1=randn(ny,1);
w2=randn(10,ny)/sqrt(ny);
b2=randn(10,1);
for epo=1:epoch
for nums=1:numimages/mini_size
    for num=(nums-1)*mini_size+1:nums*mini_size
        x=images(:,num);
        y=labels(:,num);
    net2=w1*x;               %input of net2
    hidden=1./(1+exp(-net2-b1));%output of net2
    net3=w2*hidden;            %input of net3
    o=1./(1+exp(-net3-b2));%output of net3
    %back
    delta3=(y-o);%delta of net3   由于交叉熵代价函数的引入,偏导被消去
    delta2=w2'*delta3.*(hidden.*(1-hidden));%delta of net2
    %updata w/b
    w2=w2*(1-eta*numda/numimages)+eta*delta3*hidden'/mini_size;     %L2规范化
    w1=w1*(1-eta*numda/numimages)+eta*delta2*x'/mini_size;
    b2=b2+eta*delta3/mini_size;
    b1=b1+eta*delta2/mini_size;
    end
end
%calculate sum of error
%accuracy
sum0=0;
for i=1:1000
    x0=test_data_x(:,i);
    y0=test_data_y(:,i);
    a1=[];
    a2=[];
    a1=1./(1+exp(-w1*x0-b1));
    a2=1./(1+exp(-w2*a1-b2));
    [m1,n1]=max(a2);
    [m2,n2]=max(y0);
    if n1==n2
        sum0=sum0+1;
    end
    %e=o'-y;
    %sigma(num)=e'*e;
    sigma(i)=m2*log(m1)+(1-m2)*log(1-m1);   %计算代价cost
end
sigmas(epo)=-sum(sigma)/1000;       %cost求和
fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
end
plot(sigmas);
xlabel('epoch');
ylabel('cost on the training_data');
end

好好学习,天天向上,话说都没有表情用,果然是程序猿的世界,我还是贴个表情吧

matlab处理手写识别问题的更多相关文章

  1. 基于MATLAB的手写公式识别(9)

    基于MATLAB的手写公式识别(9) 1.2图像的二值化 close all; clear all; Img=imread('drink.jpg'); %灰度化 Img_Gray=rgb2gray(I ...

  2. 基于MATLAB的手写公式识别(6)

    基于MATLAB的手写公式识别 2021-03-29 10:24:51 走通了程序,可以识别"心脑血管这几个字",还有很多不懂的地方. 2021-03-29 12:20:01 tw ...

  3. 基于MATLAB的手写公式识别(5)

    基于MATLAB的手写公式识别 总结一下昨天一天的工作成果: 获得了大致的识别过程. 一个图像从生肉到可以被处理需要经过预处理(灰质化.增加对比度.中值过滤.膨胀或腐蚀.闭环运算). 掌握了相关函数的 ...

  4. 基于MATLAB的手写公式识别(3)

    基于MATLAB的手写公式识别 图像的膨胀化,获取边缘(思考是否需要做这种处理,初始参考样本相对简单) %膨胀 imdilate(dilate=膨胀/扩大) clc clear A1=imread(' ...

  5. 基于MATLAB的手写公式识别(2)

    基于MATLAB的手写公式识别 图像的预处理(除去噪声.得到后续定位分割所需的信息.) 预处理其本质就是去除不需要的噪声信息,得到后续定位分割所需要的图像信息.图像信息在采集的过程中由于天气环境的影响 ...

  6. 基于MATLAB的手写公式识别(1)

    基于MATLAB的手写公式识别 reason:课程要求以及对MATLAB强大生命力的探索欲望: plan date:2021/3/28-2021/4/12 plan: 进行材料搜集和思路整理: 在已知 ...

  7. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  8. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  9. (手写识别) Zinnia库及其实现方法研究

    Zinnia库及其实现方法研究 (转) zinnia是一个开源的手写识别库.采用C++实现.具有手写识别,学习以及文字模型数据制作转换等功能. 项目地址 [http://zinnia.sourcefo ...

随机推荐

  1. 体验Hadoop3.0生态圈-CDH6.1时代的来临

    体验Hadoop3.0生态圈-CDH6.1时代的来临 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我在公司使用的是CDH5.15.1这个发行版本,具体的部署文档之前也有给大家分享 ...

  2. ElasticSearch的插件(Plugins)介绍

    ElasticSearch的插件(Plugins)介绍 作者:尹正杰  版权声明:原创作品,谢绝转载!否则将追究法律责任. 目前可以扩展ElasticSearch功能的插件有很多,比如:添加自定义的映 ...

  3. Centos6.6搭建Maven私服

    操作系统:Centos6.6 私服Ip:10.0.210.112 JDK:1.7 (已安装并配置好了环境变量) 1:上 传 nexus-2.11.2-03-bundle.tar.gz到/root/ne ...

  4. 关于react上线系列问题及解决方案

    近使用react做了一个音乐播放器小项目,在线下开发完成后,测试一切都没有问题,于是打算打包上线.首先注册了一个新浪云账号,然后创建了一个SAE应用实例,再然后就照着新浪云给出的远程仓库进行push. ...

  5. 立个Flag不学好PHP誓不罢休

    3年前从部队退伍退伍回来,就莫名其秒的爱上的编程,复学期间我几乎忘记了本专业的知识(原本我是读书籍设计的),从刚刚开始的C程序开始,一路走到一拿起书本我就几乎是睡着的状态,后来就开始了视频的学习之路, ...

  6. Emacs 快速指南(中文翻译)

      Emacs 快速指南 目录 1. 小结(SUMMARY) 2. 基本的光标控制(BASIC CURSOR CONTROL) 3. 如果 EMACS 失去响应(IF EMACS STOPS RESP ...

  7. python matplotlib 库学习

    基本使用 import matplotlib.pyplot as plt import numpy as np x = np.linspace(-1,1,50) y = 2*x+1 plt.figur ...

  8. mosh

    mosh 是一款使用 UDP 连接 C/S 的终端工具, 服务器只需安装好 mosh 套件, 并启动 SSH 服务, 等待 Client 连接即可. Client (mosh-client) 连接时, ...

  9. es6 javascript对象方法Object.assign() 对象的合并复制等

    Object.assign方法用于对象的合并,将源对象( source )的所有可枚举属性,复制到目标对象( target ). 详细使用稳步到前辈: http://blog.csdn.net/qq_ ...

  10. Maven的日常

    强烈建议把 Maven 的 settings.xml 文件同时放在:%USER_HOME%/.m2/settings.xml 和${maven.home}/conf/settings.xml 两个地方 ...