机器学习 demo分西瓜
周老师的书,对神经网络写了一个小的Demo
是最简单的神经网络,只有一层的隐藏层。
这次练习依旧是对西瓜的好坏进行预测。
主要分了以下几个步骤
1、数据预处理
对西瓜的不同特性进行数学编码表示(0~1),我是直接编了对应数字。含糖量已经是一个0~1之间的数,所以就没有进行处理
青绿 1
乌黑 0.5
浅白 0
蜷缩 1
稍蜷 0.5
硬挺 0
浊响 1
沉闷 0.5
清脆 0
清晰 1
稍糊 0.5
模糊 0
凹陷 1
稍凹 0.5
平坦 0
硬滑 1
软黏 0
2、训练集和检测集
- package BP;
- public class TrainData {
- double[][] traindata;
- double[][] traindataoutput;
- double[][] testdata;
- double[][] testdataoutput;
- public TrainData(){
- traindata = new double[][]{
- new double[]{1,1,1,1,1,1,0.697,0.460},
- new double[]{0.5,1,0.5,1,1,1,0.774,0.376},
- new double[]{0.5,1,1,1,1,1,0.634,0.264},
- //new double[]{1,1,0.5,1,1,1,0.608,0.318,1},
- //new double[]{0,1,1,1,1,1,0.556,0.215,1},
- new double[]{1,0.5,1,1,0.5,0,0.403,0.237},
- new double[]{0.5,0.5,1,0.5,0.5,0,0.481,0.149},
- //new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211,1},
- //new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091,0},
- //new double[]{1,0,0,1,0,0,0.243,0.267,0},
- //new double[]{0,0,0,0,0,1,0.245,0.057,0},
- //new double[]{0,1,1,0,0,0,0.343,0.099,0},
- new double[]{1,0.5,1,0.5,1,1,0.639,0.161},
- new double[]{0,0.5,0,0.5,1,1,0.657,0.198},
- new double[]{0.5,0.5,1,1,0.5,0,0.360,0.370},
- new double[]{0,1,1,0,0,1,0.593,0.042},
- new double[]{1,1,0.5,0.5,0.5,1,0.719,0.103}
- };
- traindataoutput = new double[][]{
- new double[]{1},
- new double[]{1},
- new double[]{1},
- new double[]{1},
- new double[]{1},
- new double[]{0},
- new double[]{0},
- new double[]{0},
- new double[]{0},
- new double[]{0},
- };
- testdata = new double[][]{
- new double[]{1,1,0.5,1,1,1,0.608,0.318},
- new double[]{0,1,1,1,1,1,0.556,0.215},
- new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211},
- new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091},
- new double[]{1,0,0,1,0,0,0.243,0.267},
- new double[]{0,0,0,0,0,1,0.245,0.057},
- new double[]{0,1,1,0,0,0,0.343,0.099},
- };
- testdataoutput = new double[][]{
- new double[]{1},
- new double[]{1},
- new double[]{1},
- new double[]{0},
- new double[]{0},
- new double[]{0},
- new double[]{0},
- };
- }
- public static void main(String[] args){
- TrainData t = new TrainData();
- for(int i=0;i<t.traindata.length;i++){
- for(int j=0;j<9;j++)
- System.out.print(t.traindata[i][j]+ " ");
- System.out.println();
- }
- }
- }
3、BP主函数
- package BP;
- import java.util.Random;
- public class BP {
- int innum;
- int hiddennum;
- int outnum;
- //输入、隐藏、输出层
- public double[] input;
- public double[] hidden;
- //output为本神经网络计算出的输出值
- public double[] output;
- //realoutput为训练网络时,用户提供的真的输出值
- public double[] realoutput;
- //v[i,j]表示输入层i到隐层j w[i,j]表示隐层i到输出层j
- public double[][] v;
- public double[][] w;
- //beta为隐层的阈值,afa为输出层阈值
- public double[] beta;
- public double[] afa;
- //学习率
- public double eta;
- //步长
- public double momentum;
- public final Random random;
- public BP(int inputnum,int hiddennum,int outputnum,double learningrate){
- innum = inputnum;
- this.hiddennum = hiddennum;
- outnum = outputnum;
- input = new double[inputnum + 1];
- hidden = new double[hiddennum + 1];
- output = new double[outputnum + 1];
- realoutput = new double[outputnum + 1];
- v = new double[inputnum + 1][hiddennum + 1];
- w = new double[hiddennum + 1][outputnum + 1];
- beta = new double[outputnum + 1];
- afa = new double[hiddennum + 1];
- for(int i=0;i<outputnum;i++)
- beta[i] = 0.0;
- for(int i=0;i<hiddennum;i++)
- afa[i] = 0.0;
- eta = learningrate;
- //随机数对结果影响较大
- random = new Random(19950326);
- randomizeWeights(w);
- randomizeWeights(v);
- }
- public void testData(double[] in){
- input = in;
- getNetOutput();
- }
- //只对本题目有用,output>0.5时为好西瓜,output<0.5时为坏西瓜
- public int predict(double[] in){
- testData(in);
- if(output[0]>0.5)
- return 1;
- else
- return 0;
- }
- //获得在test集上的正确率
- public double getAccuracy(double[][] in,double[][] out){
- int rightans = 0,wrongans = 0;
- for(int i=0;i<in.length;i++){
- if(predict(in[i])==(out[i][0])){
- //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
- rightans++;
- }else{
- //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
- wrongans++;
- }
- }
- System.out.println("对:"+rightans+" 错:"+wrongans);
- return (double)rightans/(double)(rightans+wrongans);
- }
- //times为进行几轮训练
- public void train(int times){
- TrainData t = new TrainData();
- double wu = 0.0,acc = 0.0;
- int n = t.traindata.length;
- for(int i=0;i<times;i++){
- wu = 0.0;
- for(int j=0;j<n;j++){
- traindata(t.traindata[j],t.traindataoutput[j]);
- wu += getDeviation();
- }
- wu = wu/((double)n);
- System.out.println("第"+i+"轮训练:"+wu);
- acc = getAccuracy(t.testdata,t.testdataoutput);
- System.out.println("预测正确率为: "+acc);
- }
- }
- //对一个input输入进行训练
- public void traindata(double[] in,double[] out){
- input = in;
- realoutput = out;
- getNetOutput();
- adjustParameter();
- }
- //获得误差E
- public double getDeviation(){
- double e = 0.0;
- for(int i=0;i<outnum;i++)
- e += (output[i] - realoutput[i])*(output[i] - realoutput[i]);
- e *= 0.5;
- return e;
- }
- //调整权值
- public void adjustParameter(){
- double g[],e = 0.0;
- g = new double[outnum];
- int i,j;
- for(i=0;i<outnum;i++){
- g[i] = output[i]*(1-output[i])*(realoutput[i]-output[i]);
- beta[i] -= eta * g[i];
- for(j=0;j<hiddennum;j++){
- w[j][i] += eta * g[i] * hidden[j];
- }
- }
- for(i=0;i<hiddennum;i++){
- e = 0.0;
- for(j=0;j<outnum;j++)
- e += g[j]*w[i][j];
- e = hidden[i]*(1-hidden[i])*e;
- afa[i] -= eta * e;
- for(j=0;j<innum;j++)
- v[j][i] += eta * e * input[j];
- }
- }
- //获得output
- public void getNetOutput(){
- int i,j;
- double tmp=0.0;
- for(i=0;i<hiddennum;i++){
- tmp = 0.0;
- for(j=0;j<innum;j++)
- tmp += v[j][i]*input[j];
- hidden[i] = sigmoid(tmp-afa[i]);
- }
- for(i=0;i<outnum;i++){
- tmp = 0.0;
- for(j=0;j<hiddennum;j++)
- tmp += w[j][i]*hidden[j];
- output[i] = sigmoid(tmp-beta[i]);
- }
- }
- //对权值矩阵w、v进行初始随机化
- private void randomizeWeights(double[][] matrix) {
- for (int i = 0, len = matrix.length; i != len; i++)
- for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
- double real = random.nextDouble();
- matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
- }
- }
- public void debug(){
- System.out.println("========begin=======");
- for(int i=0;i<innum;i++){
- for(int j=0;j<hiddennum;j++)
- System.out.print(v[i][j]+" ");
- System.out.println();
- }
- System.out.println();
- for(int i=0;i<hiddennum;i++){
- for(int j=0;j<outnum;j++)
- System.out.print(w[i][j]+" ");
- System.out.println();
- }
- System.out.println("========end=======");
- }
- public double sigmoid(double z){
- double s = 0.0;
- s = 1d/(1d + Math.exp(-z));
- return s;
- }
- public static void main(String[] args){
- BP bp = new BP(8,10,1,0.1);
- bp.train(50);
- }
- }
我要说的:
就结果来说,在验证集上的正确率可达到85%,当然很大程度上取决于BP初始化时random函数的种子。运气好的时候甚至能达到100%的正确率,运气不好的时候只有40%多,跟随便乱猜没什么区别。
想问大神。。。只能采用这种随机算法来找到一个最合适的ramdom种子值嘛?能不能用遗传这样的开放式算法进行搜索来找到最合适的随机值(我觉得随机的种子和随机结果并没有什么直接的关联,所以不知道能不能用遗传算法之列。。。)
机器学习 demo分西瓜的更多相关文章
- 分西瓜(DFS)
描述今天是阴历七月初五,acm队员zb的生日.zb正在和C小加.never在武汉集训.他想给这两位兄弟买点什么庆祝生日,经过调查,zb发现C小加和never都很喜欢吃西瓜,而且一吃就是一堆的那种,zb ...
- LASSO回归与L1正则化 西瓜书
LASSO回归与L1正则化 西瓜书 2018年04月23日 19:29:57 BIT_666 阅读数 2968更多 分类专栏: 机器学习 机器学习数学原理 西瓜书 版权声明:本文为博主原创文章,遵 ...
- 131.003 数据预处理之Dummy Variable & One-Hot Encoding
@(131 - Machine Learning | 机器学习) Demo 直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 {sex:{male, female}} ...
- CUDA程序设计(一)
为什么需要GPU 几年前我启动并主导了一个项目,当时还在谷歌,这个项目叫谷歌大脑.该项目利用谷歌的计算基础设施来构建神经网络. 规模大概比之前的神经网络扩大了一百倍,我们的方法是用约一千台电脑.这确实 ...
- ios基础篇(二十五)—— Animation动画(UIView、CoreAnimation)
Animation主要分为两类: 1.UIView属性动画 2.CoreAnimation动画 一.UIView属性动画 UIKit直接将动画集成到UIView类中,实现简单动画的创建过程.UIVie ...
- NY 325 zb的生日
假设所有西瓜重 Asum,所求的是用 Asum / 2 的背包装,最多装下多少. 刚开始用贪心作的,WA.后来用01背包,结果TLE,数据太大.原来用的是深搜! dfs(int sum, int i) ...
- backbone.Router History源码笔记
Backbone.History和Backbone.Router history和router都是控制路由的,做一个单页应用,要控制前进后退,就可以用到他们了. History类用于监听URL的变化, ...
- spring springMVC mybatis 集成
最近闲来无事,整理了一下spring springMVC mybatis 集成,关于这个话题在园子里已经有很多人写过了,我主要是想提供一个完整的demo,涵盖crud,事物控制等. 整个demo分三个 ...
- iOS百度推送的基本使用
一.iOS证书指导 在 iOS App 中加入消息推送功能时,必须要在 Apple 的开发者中心网站上申请推送证书,每一个 App 需要申请两个证书,一个在开发测试环境下使用,另一个用于上线到 App ...
随机推荐
- asp.net 中用easyui中的treegird的简单使用
几乎每个‘数人头’项目中都会用到的功能,这里先记下来,以后直接到这里复制代码就行了,ASP.NET MVC中的使用 数据库用户表中的除了有个parentid父级ID外,我还多加了以个字段,parent ...
- Win7多用户情况下,指定某一用户为自动登陆-解决办法
转自:http://sbiuggypm.themex.net/archives/605 许久没更新博客了,但从后台可以查看到,有不少朋友还是几乎每天来逛一逛,很对不起的是最近都没更新啥内容.真是不好意 ...
- create-react-app的使用及原理分析
create-react-app 是一个全局的命令行工具用来创建一个新的项目 react-scripts 是一个生成的项目所需要的开发依赖 一般我们开始创建react web应用程序的时候,要自己通过 ...
- gf框架之grpool - 高性能的goroutine池
Go语言中的goroutine虽然相对于系统线程来说比较轻量级,但是在高并发量下的goroutine频繁创建和销毁对于性能损耗以及GC来说压力也不小.充分将goroutine复用,减少goroutin ...
- 探寻main函数的“标准”写法,以及获取main函数的参数、返回值
main函数表示法 很多同学在初学C或者C++时,都见过各种各样的main函数表示法: main(){/*...*/} void main(){/*...*/} int main(){/ ...
- raft Paxos
CONSENSUS: BRIDGING THEORY AND PRACTICE https://ramcloud.stanford.edu/~ongaro/thesis.pdf https://web ...
- VS2015终极卸载方法
今天打开VS2015发现出问题了,总是停止响应,去控制面板里卸载结果像下面这样,卸载出错!于是我有开始折腾了,重新安装一遍然后,还是有问题,在卸载还是出错于是我决定通过安装介质卸载,结果,悲剧的是,启 ...
- python(50):python 向上取整 ceil 向下取整 floor 四舍五入 round
取整:ceil 向下取整:floor 四舍五入:round 使用如下:
- Android 微信支付资料收集
老板要求支持微信支付,收集了些资料做后期参考 http://www.360doc.com/content/15/0214/10/7044580_448519997.shtml http://www.t ...
- Pythonic版冒泡排序和快速排序(附:直接插入排序)
[本文出自天外归云的博客园] 冒泡排序:就是每次排序选最大元素到数组a的最后,排 len(a)-1 次.也就是两个for循环: 1. 外层是待排数组长度的循环,从待排数组长度(初始待排数组长度等于数组 ...