初学神经网络算法--梯度下降、反向传播、优化(交叉熵代价函数、L2规范化) 柔性最大值(softmax)还未领会其要义,之后再说

有点懒,暂时不想把算法重新总结,先贴一个之前做过的反向传播的总结ppt

其实python更好实现些,不过我想好好学matlab,就用matlab写了

然后是算法源码,第一个啰嗦些,不过可以帮助理解算法

  1. function bpback1(ny,eta,mini_size,epoch)
  2. %ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
  3. %该函数为梯度下降+反向传播
  4. %images
  5. [numimages,images]=bpimages('train-images.idx3-ubyte');
  6. [n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
  7. %labels
  8. [numlabels,labels]=bplabels('train-labels.idx1-ubyte');
  9. [n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
  10. %init w/b
  11. %rand('state',sum(100*clock));
  12. %ny=30;eta=0.01;mini_size=10;
  13. w1=randn(ny,784);
  14. b1=randn(ny,1);
  15. w2=randn(10,ny);
  16. b2=randn(10,1);
  17. for epo=1:epoch
  18. for nums=1:numimages/mini_size
  19. for num=(nums-1)*mini_size+1:nums*mini_size
  20. x=images(:,num);
  21. y=labels(:,num);
  22. net2=w1*x; %input of net2
  23. for i=1:ny
  24. hidden(i)=1/(1+exp(-net2(i)-b1(i)));%output of net2
  25. end
  26. net3=w2*hidden'; %input of net3
  27. for i=1:10
  28. o(i)=1/(1+exp(-net3(i)-b2(i)));%output of net3
  29. end
  30.  
  31. %back
  32. for i=1:10
  33. delta3(i)=(y(i)-o(i))*o(i)*(1-o(i));%delta of net3
  34. end
  35. for i=1:ny
  36. delta2(i)=delta3*w2(:,i)*hidden(i)*(1-hidden(i));%delta of net2
  37. end
  38. %updata w/b
  39. for i=1:10
  40. for j=1:ny
  41. w2(i,j)=w2(i,j)+eta*delta3(i)*hidden(j)/mini_size;
  42. end
  43. end
  44. for i=1:ny
  45. for j=1:784
  46. w1(i,j)=w1(i,j)+eta*delta2(i)*x(j)/mini_size;
  47. end
  48. end
  49. for i=1:10
  50. b2(i)=b2(i)+eta*delta3(i);
  51. end
  52. for i=1:ny
  53. b1(i)=b1(i)+eta*delta2(i);
  54. end
  55. end
  56. end
  57. %calculate sum of error
  58. %accuracy
  59. sum0=0;
  60. for i=1:1000
  61. x0=test_data_x(:,i);
  62. y0=test_data_y(:,i);
  63. a1=[];
  64. a2=[];
  65. s1=w1*x0;
  66. for j=1:ny
  67. a1(j)=1/(1+exp(-s1(j)-b1(j)));
  68. end
  69. s2=w2*a1';
  70. for j=1:10
  71. a2(j)=1/(1+exp(-s2(j)-b2(j)));
  72. end
  73. a2=a2';
  74. [m1,n1]=max(a2);
  75. [m2,n2]=max(y0);
  76. if n1==n2
  77. sum0=sum0+1;
  78. end
  79. %e=o'-y;
  80. %sigma(num)=e'*e;
  81. sigma(i)=sumsqr(a2-y0); %代价为误差平方和
  82. end
  83. sigmas(epo)=sum(sigma)/(2*1000);
  84. fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
  85. end
  86. plot(sigmas);
  87. xlabel('epoch');
  88. ylabel('cost on the training_data');
  89. end

  

  1. function bpback2(ny,eta,mini_size,epoch,numda)
  2. %ny:隐藏层为1层,神经元数目为ny;eta:学习速率;mini_size:最小采样;eopch:迭代次数
  3. %bpback的优化,包括L2规范化、交叉熵代价函数的引入---结果证明该优化非常赞!
  4. %images
  5. [numimages,images]=bpimages('train-images.idx3-ubyte');
  6. [n_test,test_data_x]=bpimages('t10k-images.idx3-ubyte');
  7. %labels
  8. [numlabels,labels]=bplabels('train-labels.idx1-ubyte');
  9. [n_test,test_data_y]=bplabels('t10k-labels.idx1-ubyte');
  10. %init w/b
  11. %ny=30;eta=0.05;mini_size=10;epoch=10;numda=0.1;
  12. rand('state',sum(100*clock));
  13. w1=randn(ny,784)/sqrt(784);
  14. b1=randn(ny,1);
  15. w2=randn(10,ny)/sqrt(ny);
  16. b2=randn(10,1);
  17. for epo=1:epoch
  18. for nums=1:numimages/mini_size
  19. for num=(nums-1)*mini_size+1:nums*mini_size
  20. x=images(:,num);
  21. y=labels(:,num);
  22. net2=w1*x; %input of net2
  23. hidden=1./(1+exp(-net2-b1));%output of net2
  24. net3=w2*hidden; %input of net3
  25. o=1./(1+exp(-net3-b2));%output of net3
  26. %back
  27. delta3=(y-o);%delta of net3 由于交叉熵代价函数的引入,偏导被消去
  28. delta2=w2'*delta3.*(hidden.*(1-hidden));%delta of net2
  29. %updata w/b
  30. w2=w2*(1-eta*numda/numimages)+eta*delta3*hidden'/mini_size; %L2规范化
  31. w1=w1*(1-eta*numda/numimages)+eta*delta2*x'/mini_size;
  32. b2=b2+eta*delta3/mini_size;
  33. b1=b1+eta*delta2/mini_size;
  34. end
  35. end
  36. %calculate sum of error
  37. %accuracy
  38. sum0=0;
  39. for i=1:1000
  40. x0=test_data_x(:,i);
  41. y0=test_data_y(:,i);
  42. a1=[];
  43. a2=[];
  44. a1=1./(1+exp(-w1*x0-b1));
  45. a2=1./(1+exp(-w2*a1-b2));
  46. [m1,n1]=max(a2);
  47. [m2,n2]=max(y0);
  48. if n1==n2
  49. sum0=sum0+1;
  50. end
  51. %e=o'-y;
  52. %sigma(num)=e'*e;
  53. sigma(i)=m2*log(m1)+(1-m2)*log(1-m1); %计算代价cost
  54. end
  55. sigmas(epo)=-sum(sigma)/1000; %cost求和
  56. fprintf('epoch %d:%d/%d\n',epo,sum0,1000);
  57. end
  58. plot(sigmas);
  59. xlabel('epoch');
  60. ylabel('cost on the training_data');
  61. end

好好学习,天天向上,话说都没有表情用,果然是程序猿的世界,我还是贴个表情吧

matlab处理手写识别问题的更多相关文章

  1. 基于MATLAB的手写公式识别(9)

    基于MATLAB的手写公式识别(9) 1.2图像的二值化 close all; clear all; Img=imread('drink.jpg'); %灰度化 Img_Gray=rgb2gray(I ...

  2. 基于MATLAB的手写公式识别(6)

    基于MATLAB的手写公式识别 2021-03-29 10:24:51 走通了程序,可以识别"心脑血管这几个字",还有很多不懂的地方. 2021-03-29 12:20:01 tw ...

  3. 基于MATLAB的手写公式识别(5)

    基于MATLAB的手写公式识别 总结一下昨天一天的工作成果: 获得了大致的识别过程. 一个图像从生肉到可以被处理需要经过预处理(灰质化.增加对比度.中值过滤.膨胀或腐蚀.闭环运算). 掌握了相关函数的 ...

  4. 基于MATLAB的手写公式识别(3)

    基于MATLAB的手写公式识别 图像的膨胀化,获取边缘(思考是否需要做这种处理,初始参考样本相对简单) %膨胀 imdilate(dilate=膨胀/扩大) clc clear A1=imread(' ...

  5. 基于MATLAB的手写公式识别(2)

    基于MATLAB的手写公式识别 图像的预处理(除去噪声.得到后续定位分割所需的信息.) 预处理其本质就是去除不需要的噪声信息,得到后续定位分割所需要的图像信息.图像信息在采集的过程中由于天气环境的影响 ...

  6. 基于MATLAB的手写公式识别(1)

    基于MATLAB的手写公式识别 reason:课程要求以及对MATLAB强大生命力的探索欲望: plan date:2021/3/28-2021/4/12 plan: 进行材料搜集和思路整理: 在已知 ...

  7. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

  8. JS / Egret 单笔手写识别、手势识别

    UnistrokeRecognizer 单笔手写识别.手势识别 UnistrokeRecognizer : https://github.com/RichLiu1023/UnistrokeRecogn ...

  9. (手写识别) Zinnia库及其实现方法研究

    Zinnia库及其实现方法研究 (转) zinnia是一个开源的手写识别库.采用C++实现.具有手写识别,学习以及文字模型数据制作转换等功能. 项目地址 [http://zinnia.sourcefo ...

随机推荐

  1. Spring Boot笔记五: Web开发之Webjar和静态资源映射规则

    目录 Webjar /** 访问当前项目的任何资源 欢迎页 标签页图标 Webjar 开始讲到Spring Boot的Web开发了,先介绍Webjar,这个其实就是把一些前端资源以jar包的形式导入到 ...

  2. springboot整合mybatis出现的一些问题

    springboot整合mybatis非常非常的简单,简直简单到发指.但是也有一些坑,这里我会详细的指出会遇到什么问题,并且这些配置的作用 整合mybatis,无疑需要mapper文件,实体类,dao ...

  3. css 绝对定位实现水平垂直居中

    负margin实现水平垂直居中 width: 500px; height: 500px; position: absolute; left: 50%; top :50%; margin-left: - ...

  4. 网络编程之Socket & ServerSocket

    网络编程之Socket & ServerSocket Socket:网络套接字,网络插座,建立网络通信连接至少要一对端口号(socket).socket本质是编程接口(API),对TCP/IP ...

  5. SpringBoot系列: 极简Demo程序和Tomcat war包部署

    =================================SpringBoot 标准项目创建步骤================================= 使用 Spring IDE( ...

  6. Devexpress dll搜集

    Devexpress一部分在全局dll中,需要分析缺哪些dll,有两种方式1.打包,安装时会自动提示 2.使用自带分析工具Assembly deployment tool

  7. [C++]2-3 倒三角形

    /* 倒三角形(Triangle) 输入正整数n<=20,输出一个n层的倒等腰三角形. 0 ######### 9 = 2* n-1 1 ####### 7 = 2*(n-1)-1 2 #### ...

  8. iFrame跨域解决办法

    按情境分1.不跨域时2.主域相同.子域不同时3.主域不同不跨域时访问iframe: contentWindow访问父级:parent访问顶级:top a.html <html xmlns=&qu ...

  9. Java -cp 命令行引用多个jar包的简单写法(Windows、Linux

    1.Windows下用法 在Windows上,可以使用 用法:java your-jar-lib-folder/* your-main-class your-jar-lib-folder为存放一堆ja ...

  10. springboot+freemarker

    springboot添加freemarker支持 1.application.properties中添加配置 #freemarker config spring.freemarker.allow-re ...