初学神经网络算法--梯度下降、反向传播、优化(交叉熵代价函数、L2规范化) 柔性最大值(softmax)还未领会其要义,之后再说

有点懒,暂时不想把算法重新总结,先贴一个之前做过的反向传播的总结ppt

其实python更好实现些,不过我想好好学matlab,就用matlab写了

然后是算法源码,第一个啰嗦些,不过可以帮助理解算法

function bpback1(ny,eta,mini_size,epoch)
%ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
%该函数为梯度下降+反向传播
%images
[numimages,images]=bpimages('train-images.idx3-ubyte');
[n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
%labels
[numlabels,labels]=bplabels('train-labels.idx1-ubyte');
[n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
%init w/b
%rand('state',sum(100*clock));
%ny=30;eta=0.01;mini_size=10;
w1=randn(ny,784);
b1=randn(ny,1);
w2=randn(10,ny);
b2=randn(10,1);
for epo=1:epoch
for nums=1:numimages/mini_size
    for num=(nums-1)*mini_size+1:nums*mini_size
        x=images(:,num);
        y=labels(:,num);
    net2=w1*x;               %input of net2
    for i=1:ny
    hidden(i)=1/(1+exp(-net2(i)-b1(i)));%output of net2
    end
    net3=w2*hidden';            %input of net3
    for i=1:10
    o(i)=1/(1+exp(-net3(i)-b2(i)));%output of net3
    end

    %back
    for i=1:10
    delta3(i)=(y(i)-o(i))*o(i)*(1-o(i));%delta of net3
    end
    for i=1:ny
    delta2(i)=delta3*w2(:,i)*hidden(i)*(1-hidden(i));%delta of net2
    end
    %updata w/b
    for i=1:10
        for j=1:ny
    w2(i,j)=w2(i,j)+eta*delta3(i)*hidden(j)/mini_size;
        end
    end
    for i=1:ny
        for j=1:784
    w1(i,j)=w1(i,j)+eta*delta2(i)*x(j)/mini_size;
        end
    end
    for i=1:10
    b2(i)=b2(i)+eta*delta3(i);
    end
    for i=1:ny
    b1(i)=b1(i)+eta*delta2(i);
    end
    end
end
%calculate sum of error
%accuracy
sum0=0;
for i=1:1000
    x0=test_data_x(:,i);
    y0=test_data_y(:,i);
    a1=[];
    a2=[];
    s1=w1*x0;
    for j=1:ny
    a1(j)=1/(1+exp(-s1(j)-b1(j)));
    end
    s2=w2*a1';
    for j=1:10
    a2(j)=1/(1+exp(-s2(j)-b2(j)));
    end
    a2=a2';
    [m1,n1]=max(a2);
    [m2,n2]=max(y0);
    if n1==n2
        sum0=sum0+1;
    end
    %e=o'-y;
    %sigma(num)=e'*e;
    sigma(i)=sumsqr(a2-y0);   %代价为误差平方和
end
sigmas(epo)=sum(sigma)/(2*1000);
fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
end
plot(sigmas);
xlabel('epoch');
ylabel('cost on the training_data');
end

  

function bpback2(ny,eta,mini_size,epoch,numda)
%ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
%bpback的优化,包括L2规范化、交叉熵代价函数的引入---结果证明该优化非常赞!
%images
[numimages,images]=bpimages('train-images.idx3-ubyte');
[n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
%labels
[numlabels,labels]=bplabels('train-labels.idx1-ubyte');
[n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
%init w/b
%ny=30;eta=0.05;mini_size=10;epoch=10;numda=0.1;
rand('state',sum(100*clock));
w1=randn(ny,784)/sqrt(784);
b1=randn(ny,1);
w2=randn(10,ny)/sqrt(ny);
b2=randn(10,1);
for epo=1:epoch
for nums=1:numimages/mini_size
    for num=(nums-1)*mini_size+1:nums*mini_size
        x=images(:,num);
        y=labels(:,num);
    net2=w1*x;               %input of net2
    hidden=1./(1+exp(-net2-b1));%output of net2
    net3=w2*hidden;            %input of net3
    o=1./(1+exp(-net3-b2));%output of net3
    %back
    delta3=(y-o);%delta of net3   由于交叉熵代价函数的引入,偏导被消去
    delta2=w2'*delta3.*(hidden.*(1-hidden));%delta of net2
    %updata w/b
    w2=w2*(1-eta*numda/numimages)+eta*delta3*hidden'/mini_size;     %L2规范化
    w1=w1*(1-eta*numda/numimages)+eta*delta2*x'/mini_size;
    b2=b2+eta*delta3/mini_size;
    b1=b1+eta*delta2/mini_size;
    end
end
%calculate sum of error
%accuracy
sum0=0;
for i=1:1000
    x0=test_data_x(:,i);
    y0=test_data_y(:,i);
    a1=[];
    a2=[];
    a1=1./(1+exp(-w1*x0-b1));
    a2=1./(1+exp(-w2*a1-b2));
    [m1,n1]=max(a2);
    [m2,n2]=max(y0);
    if n1==n2
        sum0=sum0+1;
    end
    %e=o'-y;
    %sigma(num)=e'*e;
    sigma(i)=m2*log(m1)+(1-m2)*log(1-m1);   %计算代价cost
end
sigmas(epo)=-sum(sigma)/1000;       %cost求和
fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
end
plot(sigmas);
xlabel('epoch');
ylabel('cost on the training_data');
end

好好学习,天天向上,话说都没有表情用,果然是程序猿的世界,我还是贴个表情吧

matlab处理手写识别问题的更多相关文章

  1. 基于MATLAB的手写公式识别(9)

    基于MATLAB的手写公式识别(9) 1.2图像的二值化 close all; clear all; Img=imread('drink.jpg'); %灰度化 Img_Gray=rgb2gray(I ...

  2. 基于MATLAB的手写公式识别(6)

    基于MATLAB的手写公式识别 2021-03-29 10:24:51 走通了程序,可以识别"心脑血管这几个字",还有很多不懂的地方. 2021-03-29 12:20:01 tw ...

  3. 基于MATLAB的手写公式识别(5)

    基于MATLAB的手写公式识别 总结一下昨天一天的工作成果: 获得了大致的识别过程. 一个图像从生肉到可以被处理需要经过预处理(灰质化.增加对比度.中值过滤.膨胀或腐蚀.闭环运算). 掌握了相关函数的 ...

  4. 基于MATLAB的手写公式识别(3)

    基于MATLAB的手写公式识别 图像的膨胀化,获取边缘(思考是否需要做这种处理,初始参考样本相对简单) %膨胀 imdilate(dilate=膨胀/扩大) clc clear A1=imread(' ...

  5. 基于MATLAB的手写公式识别(2)

    基于MATLAB的手写公式识别 图像的预处理(除去噪声.得到后续定位分割所需的信息.) 预处理其本质就是去除不需要的噪声信息,得到后续定位分割所需要的图像信息.图像信息在采集的过程中由于天气环境的影响 ...

  6. 基于MATLAB的手写公式识别(1)

    基于MATLAB的手写公式识别 reason:课程要求以及对MATLAB强大生命力的探索欲望: plan date:2021/3/28-2021/4/12 plan: 进行材料搜集和思路整理: 在已知 ...

  7. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  8. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  9. (手写识别) Zinnia库及其实现方法研究

    Zinnia库及其实现方法研究 (转) zinnia是一个开源的手写识别库.采用C++实现.具有手写识别,学习以及文字模型数据制作转换等功能. 项目地址 [http://zinnia.sourcefo ...

随机推荐

  1. Java Web之Tomcat

    Tomcat的下载安装配置什么的,百度一大把.现在介绍一下Tomcat的文件夹目录结构. 浏览器访问127.0.0.1:8080 出现Tomcat页面即表示安装成功. 这个就是Tomcat的目录了 b ...

  2. python mysql 视图 触发器 事物 存储过程 用户授权 数据备份还原

    ###################总结########### 视图是一个虚拟表(非真实存在) 是跑在内存中的表,真实表是在硬盘上的表 使用视图我们可以把查询过程中的临时表摘出来,保存下来,用视图去 ...

  3. HDU 1046(最短路径 **)

    题意是要在一个矩形点阵中求能从一点出发遍历所有点再回到起始点的最短路径长度. 不需要用到搜索什么的,可以走一个“梳子型”即可完成最短路径,而情况可以被分成如下两种: 一.矩形的长或宽中有偶数,则可以走 ...

  4. Linux 三剑客 -- awk sed grep

    本文由本人收集整理自互联网供自己与网友参考,参考文章均已列出,如有侵权,请告知! 顶配awk,中配sed,标配grep awk 参考 sed 参考 grep 参考 在线查看linux命令速记表 app ...

  5. 未启用当前数据库的 SQL Server Service Broker,因此查询通知不受支持。如果希望使用通知,请为此数据库启用 Service Broker

    昨晚遇到的这个问题,也知道Notifications service依赖底层的Service broker的.本以为只需要执行以下脚本对数据库启用Service broker即可. alter dat ...

  6. Android几个比较有用的插件

    1.Android  Drawable Importer 2.Android ButterKnife Zelezny 使用方法,在SetContentView上右键,Generate 3.Androi ...

  7. Newtonsoft.Json 的基本用法

    Ø  前言 说起 C# 对 JSON 的操作(序列化与反序列化),大家都会想到 JavaScriptSerializer.DataContractJsonSerializer 与 Newtonsoft ...

  8. Hibernate常用API以及使用说明

    1===>Hibernate常用的aip有Configuration,SessionFactory,Transaction,Session Configuration主要用于加载配置文件,使用 ...

  9. 四、文件IO——内核数据结构和原子操作

    4.1 缓存 buff 说明 一般设置缓存 buff  的大小是由一定的规律的,就是根据磁盘块的大小来定. Linux下输入命令: df -k  查看磁盘 可以用命令查看下 /dev/sda1 磁盘的 ...

  10. gcd 二进制/循环

    #include<bits/stdc++.h> #define LL long long using namespace std; inline aabs(LL x){ ?x:-x;} i ...