周老师的书,对神经网络写了一个小的Demo

是最简单的神经网络,只有一层的隐藏层。

这次练习依旧是对西瓜的好坏进行预测。

主要分了以下几个步骤

1、数据预处理

对西瓜的不同特性进行数学编码表示(0~1),我是直接编了对应数字。含糖量已经是一个0~1之间的数,所以就没有进行处理

青绿  1

乌黑 0.5

浅白  0

蜷缩  1

稍蜷 0.5

硬挺  0

浊响  1

沉闷 0.5

清脆  0

清晰  1

稍糊 0.5

模糊  0

凹陷  1

稍凹 0.5

平坦  0

硬滑  1

软黏  0

2、训练集和检测集

  1. package BP;
  2. public class TrainData {
  3. double[][] traindata;
  4. double[][] traindataoutput;
  5. double[][] testdata;
  6. double[][] testdataoutput;
  7. public TrainData(){
  8. traindata = new double[][]{
  9. new double[]{1,1,1,1,1,1,0.697,0.460},
  10. new double[]{0.5,1,0.5,1,1,1,0.774,0.376},
  11. new double[]{0.5,1,1,1,1,1,0.634,0.264},
  12. //new double[]{1,1,0.5,1,1,1,0.608,0.318,1},
  13. //new double[]{0,1,1,1,1,1,0.556,0.215,1},
  14. new double[]{1,0.5,1,1,0.5,0,0.403,0.237},
  15. new double[]{0.5,0.5,1,0.5,0.5,0,0.481,0.149},
  16. //new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211,1},
  17. //new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091,0},
  18. //new double[]{1,0,0,1,0,0,0.243,0.267,0},
  19. //new double[]{0,0,0,0,0,1,0.245,0.057,0},
  20. //new double[]{0,1,1,0,0,0,0.343,0.099,0},
  21. new double[]{1,0.5,1,0.5,1,1,0.639,0.161},
  22. new double[]{0,0.5,0,0.5,1,1,0.657,0.198},
  23. new double[]{0.5,0.5,1,1,0.5,0,0.360,0.370},
  24. new double[]{0,1,1,0,0,1,0.593,0.042},
  25. new double[]{1,1,0.5,0.5,0.5,1,0.719,0.103}
  26. };
  27. traindataoutput = new double[][]{
  28. new double[]{1},
  29. new double[]{1},
  30. new double[]{1},
  31. new double[]{1},
  32. new double[]{1},
  33. new double[]{0},
  34. new double[]{0},
  35. new double[]{0},
  36. new double[]{0},
  37. new double[]{0},
  38. };
  39. testdata = new double[][]{
  40. new double[]{1,1,0.5,1,1,1,0.608,0.318},
  41. new double[]{0,1,1,1,1,1,0.556,0.215},
  42. new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211},
  43. new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091},
  44. new double[]{1,0,0,1,0,0,0.243,0.267},
  45. new double[]{0,0,0,0,0,1,0.245,0.057},
  46. new double[]{0,1,1,0,0,0,0.343,0.099},
  47. };
  48. testdataoutput = new double[][]{
  49. new double[]{1},
  50. new double[]{1},
  51. new double[]{1},
  52. new double[]{0},
  53. new double[]{0},
  54. new double[]{0},
  55. new double[]{0},
  56. };
  57. }
  58. public static void main(String[] args){
  59. TrainData t = new TrainData();
  60. for(int i=0;i<t.traindata.length;i++){
  61. for(int j=0;j<9;j++)
  62. System.out.print(t.traindata[i][j]+ " ");
  63. System.out.println();
  64. }
  65. }
  66. }

3、BP主函数

  1. package BP;
  2. import java.util.Random;
  3. public class BP {
  4. int innum;
  5. int hiddennum;
  6. int outnum;
  7. //输入、隐藏、输出层
  8. public double[] input;
  9. public double[] hidden;
  10. //output为本神经网络计算出的输出值
  11. public double[] output;
  12. //realoutput为训练网络时,用户提供的真的输出值
  13. public double[] realoutput;
  14. //v[i,j]表示输入层i到隐层j  w[i,j]表示隐层i到输出层j
  15. public double[][] v;
  16. public double[][] w;
  17. //beta为隐层的阈值,afa为输出层阈值
  18. public double[] beta;
  19. public double[] afa;
  20. //学习率
  21. public double eta;
  22. //步长
  23. public double momentum;
  24. public final Random random;
  25. public BP(int inputnum,int hiddennum,int outputnum,double learningrate){
  26. innum = inputnum;
  27. this.hiddennum = hiddennum;
  28. outnum = outputnum;
  29. input = new double[inputnum + 1];
  30. hidden = new double[hiddennum + 1];
  31. output = new double[outputnum + 1];
  32. realoutput = new double[outputnum + 1];
  33. v = new double[inputnum + 1][hiddennum + 1];
  34. w = new double[hiddennum + 1][outputnum + 1];
  35. beta = new double[outputnum + 1];
  36. afa = new double[hiddennum + 1];
  37. for(int i=0;i<outputnum;i++)
  38. beta[i] = 0.0;
  39. for(int i=0;i<hiddennum;i++)
  40. afa[i] = 0.0;
  41. eta = learningrate;
  42. //随机数对结果影响较大
  43. random = new Random(19950326);
  44. randomizeWeights(w);
  45. randomizeWeights(v);
  46. }
  47. public void testData(double[] in){
  48. input = in;
  49. getNetOutput();
  50. }
  51. //只对本题目有用,output>0.5时为好西瓜,output<0.5时为坏西瓜
  52. public int predict(double[] in){
  53. testData(in);
  54. if(output[0]>0.5)
  55. return 1;
  56. else
  57. return 0;
  58. }
  59. //获得在test集上的正确率
  60. public double getAccuracy(double[][] in,double[][] out){
  61. int rightans = 0,wrongans = 0;
  62. for(int i=0;i<in.length;i++){
  63. if(predict(in[i])==(out[i][0])){
  64. //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
  65. rightans++;
  66. }else{
  67. //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
  68. wrongans++;
  69. }
  70. }
  71. System.out.println("对:"+rightans+" 错:"+wrongans);
  72. return (double)rightans/(double)(rightans+wrongans);
  73. }
  74. //times为进行几轮训练
  75. public void train(int times){
  76. TrainData t = new TrainData();
  77. double wu = 0.0,acc = 0.0;
  78. int n = t.traindata.length;
  79. for(int i=0;i<times;i++){
  80. wu = 0.0;
  81. for(int j=0;j<n;j++){
  82. traindata(t.traindata[j],t.traindataoutput[j]);
  83. wu += getDeviation();
  84. }
  85. wu = wu/((double)n);
  86. System.out.println("第"+i+"轮训练:"+wu);
  87. acc = getAccuracy(t.testdata,t.testdataoutput);
  88. System.out.println("预测正确率为: "+acc);
  89. }
  90. }
  91. //对一个input输入进行训练
  92. public void traindata(double[] in,double[] out){
  93. input = in;
  94. realoutput = out;
  95. getNetOutput();
  96. adjustParameter();
  97. }
  98. //获得误差E
  99. public double getDeviation(){
  100. double e = 0.0;
  101. for(int i=0;i<outnum;i++)
  102. e += (output[i] - realoutput[i])*(output[i] - realoutput[i]);
  103. e *= 0.5;
  104. return e;
  105. }
  106. //调整权值
  107. public void adjustParameter(){
  108. double g[],e = 0.0;
  109. g = new double[outnum];
  110. int i,j;
  111. for(i=0;i<outnum;i++){
  112. g[i] = output[i]*(1-output[i])*(realoutput[i]-output[i]);
  113. beta[i] -= eta * g[i];
  114. for(j=0;j<hiddennum;j++){
  115. w[j][i] += eta * g[i] * hidden[j];
  116. }
  117. }
  118. for(i=0;i<hiddennum;i++){
  119. e = 0.0;
  120. for(j=0;j<outnum;j++)
  121. e += g[j]*w[i][j];
  122. e = hidden[i]*(1-hidden[i])*e;
  123. afa[i] -= eta * e;
  124. for(j=0;j<innum;j++)
  125. v[j][i] += eta * e * input[j];
  126. }
  127. }
  128. //获得output
  129. public void getNetOutput(){
  130. int i,j;
  131. double tmp=0.0;
  132. for(i=0;i<hiddennum;i++){
  133. tmp = 0.0;
  134. for(j=0;j<innum;j++)
  135. tmp += v[j][i]*input[j];
  136. hidden[i] = sigmoid(tmp-afa[i]);
  137. }
  138. for(i=0;i<outnum;i++){
  139. tmp = 0.0;
  140. for(j=0;j<hiddennum;j++)
  141. tmp += w[j][i]*hidden[j];
  142. output[i] = sigmoid(tmp-beta[i]);
  143. }
  144. }
  145. //对权值矩阵w、v进行初始随机化
  146. private void randomizeWeights(double[][] matrix) {
  147. for (int i = 0, len = matrix.length; i != len; i++)
  148. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  149. double real = random.nextDouble();
  150. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  151. }
  152. }
  153. public void debug(){
  154. System.out.println("========begin=======");
  155. for(int i=0;i<innum;i++){
  156. for(int j=0;j<hiddennum;j++)
  157. System.out.print(v[i][j]+" ");
  158. System.out.println();
  159. }
  160. System.out.println();
  161. for(int i=0;i<hiddennum;i++){
  162. for(int j=0;j<outnum;j++)
  163. System.out.print(w[i][j]+" ");
  164. System.out.println();
  165. }
  166. System.out.println("========end=======");
  167. }
  168. public double sigmoid(double z){
  169. double s = 0.0;
  170. s = 1d/(1d + Math.exp(-z));
  171. return s;
  172. }
  173. public static void main(String[] args){
  174. BP bp = new BP(8,10,1,0.1);
  175. bp.train(50);
  176. }
  177. }

我要说的:

就结果来说,在验证集上的正确率可达到85%,当然很大程度上取决于BP初始化时random函数的种子。运气好的时候甚至能达到100%的正确率,运气不好的时候只有40%多,跟随便乱猜没什么区别。

想问大神。。。只能采用这种随机算法来找到一个最合适的ramdom种子值嘛?能不能用遗传这样的开放式算法进行搜索来找到最合适的随机值(我觉得随机的种子和随机结果并没有什么直接的关联,所以不知道能不能用遗传算法之列。。。)

机器学习 demo分西瓜的更多相关文章

  1. 分西瓜(DFS)

    描述今天是阴历七月初五,acm队员zb的生日.zb正在和C小加.never在武汉集训.他想给这两位兄弟买点什么庆祝生日,经过调查,zb发现C小加和never都很喜欢吃西瓜,而且一吃就是一堆的那种,zb ...

  2. LASSO回归与L1正则化 西瓜书

    LASSO回归与L1正则化 西瓜书 2018年04月23日 19:29:57 BIT_666 阅读数 2968更多 分类专栏: 机器学习 机器学习数学原理 西瓜书   版权声明:本文为博主原创文章,遵 ...

  3. 131.003 数据预处理之Dummy Variable & One-Hot Encoding

    @(131 - Machine Learning | 机器学习) Demo 直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 {sex:{male, female}}​ ...

  4. CUDA程序设计(一)

    为什么需要GPU 几年前我启动并主导了一个项目,当时还在谷歌,这个项目叫谷歌大脑.该项目利用谷歌的计算基础设施来构建神经网络. 规模大概比之前的神经网络扩大了一百倍,我们的方法是用约一千台电脑.这确实 ...

  5. ios基础篇(二十五)—— Animation动画(UIView、CoreAnimation)

    Animation主要分为两类: 1.UIView属性动画 2.CoreAnimation动画 一.UIView属性动画 UIKit直接将动画集成到UIView类中,实现简单动画的创建过程.UIVie ...

  6. NY 325 zb的生日

    假设所有西瓜重 Asum,所求的是用 Asum / 2 的背包装,最多装下多少. 刚开始用贪心作的,WA.后来用01背包,结果TLE,数据太大.原来用的是深搜! dfs(int sum, int i) ...

  7. backbone.Router History源码笔记

    Backbone.History和Backbone.Router history和router都是控制路由的,做一个单页应用,要控制前进后退,就可以用到他们了. History类用于监听URL的变化, ...

  8. spring springMVC mybatis 集成

    最近闲来无事,整理了一下spring springMVC mybatis 集成,关于这个话题在园子里已经有很多人写过了,我主要是想提供一个完整的demo,涵盖crud,事物控制等. 整个demo分三个 ...

  9. iOS百度推送的基本使用

    一.iOS证书指导 在 iOS App 中加入消息推送功能时,必须要在 Apple 的开发者中心网站上申请推送证书,每一个 App 需要申请两个证书,一个在开发测试环境下使用,另一个用于上线到 App ...

随机推荐

  1. Unity预计算全局光照的学习(速度优化,LightProbe,LPPV)

    1.基本参数与使用 1.1 常规介绍 使用预计算光照需要在Window/Lighting面板下找到预计算光照选项,保持勾选预计算光照并保证场景中有一个光照静态的物体 此时在编辑器内构建后,预计算光照开 ...

  2. IOS解惑(1)之@property(nonatomic,getter=isOn) BOOL on;中的getter解惑

    1 问题: @property(nonatomic,getter=isOn) BOOL on; 中的getter = isOn的含义? 2 答案: 如果这个property是 BOOL on, 那么O ...

  3. SQL作业

    USE [test] GO /****** Object: StoredProcedure [dbo].[wangchuang] Script Date: 2016/8/25 14:09:24 *** ...

  4. Checkbox: ListView 与CheckBox 触发事件冲突的问题

    我相信很多人都遇到过 ListView 中放入checkBox ,会导致ListView的OnItemClickListener无效,这是怎么回事呢? 这是因为checkBox 的点击事件的优先级比L ...

  5. Market Guide for AIOps Platforms

    AIOps platforms enhance IT operations through greater insights by combining big data, machine learni ...

  6. .NET MVC5+ Dapper+扩展+微软Unity依赖注入实例

    1.dapper和dapper扩展需要在线安装或者引用DLL即可 使用nuget为项目增加Unity相关的包 2.model类 public class UserInfo { public int I ...

  7. 每日英语:The Most Destructive, Unpredictable Force in Tech

    What's the most destructive force in the tech world, the thing that has nearly killed BlackBerry, pu ...

  8. 关于ddx/ddy重建法线在edge边沿上的artifacts问题

    经验证,原来ddx/ddy这两个操作,在forward rendering与deferred rendering中存在着微妙的应用区别. 在forward rendering中,GPU shader会 ...

  9. 使用windowAnimations定义Activity及Dialog的进入退出效果

    看了android的源代码和资源文件,终于明白如何去修改设置Dialog和Activity的进入和退出效果了.设置Dialog首先通过getWindow()方法获取它的窗口,然后通过getAttrib ...

  10. Spark中groupBy groupByKey reduceByKey的区别

    groupBy 和SQL中groupby一样,只是后面必须结合聚合函数使用才可以. 例如: hour.filter($"version".isin(version: _*)).gr ...