自己实现的SVM源码
首先是Data类:
import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Scanner;

/**
 * Loads labelled training samples from a tab-separated text file.
 *
 * <p>Each non-blank line is expected to contain one or more numeric feature
 * columns followed by an integer class label in the last column.
 */
public class Data {

    /** Training file read by the no-argument {@link #getTrainData()}. */
    private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";

    /**
     * Reads the training set from the default file location.
     *
     * @return feature vector to label mapping, in file order; empty if the
     *         file cannot be found
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_TRAIN_FILE);
    }

    /**
     * Reads a training set from the given tab-separated file.
     *
     * <p>Blank lines are skipped. Because samples are keyed by their feature
     * vector, two lines with identical features collapse into one entry and
     * the later label wins.
     *
     * @param path location of the training file
     * @return feature vector to label mapping, in file order; empty if the
     *         file cannot be found
     * @throws NumberFormatException if a non-blank line holds a value that is
     *         not a valid number
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        // LinkedHashMap keeps the samples in file order so that training runs
        // are reproducible.
        Map<List<Double>, Integer> data = new LinkedHashMap<List<Double>, Integer>();
        // try-with-resources guarantees the file handle is released.
        try (Scanner in = new Scanner(new File(path))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A trailing or separating blank line is not a sample.
                    continue;
                }
                String[] fields = line.split("\t");
                int labelIndex = fields.length - 1;
                List<Double> point = new ArrayList<>(labelIndex);
                for (int i = 0; i < labelIndex; i++) {
                    point.add(Double.parseDouble(fields[i]));
                }
                data.put(point, Integer.parseInt(fields[labelIndex]));
            }
        } catch (FileNotFoundException e) {
            // Best effort, as before: report the problem and return what we have.
            System.err.println("Training file not found: " + path + " (" + e.getMessage() + ")");
        }
        return data;
    }

    public static void main(String[] args) {
        Data data = new Data();
        data.getTrainData();
    }
}
SVM类:
import java.awt.print.Printable;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream.GetField;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry; public class SVM {
private List<ArrayList<Double>> trainData;
private List<Integer> labelTrainData;
private double sigma;
private double C;
private List<Double> alpha;
private double b;
private List<Double> E;
private int N;
private int dim;
private double tol;
private double eta;
private double eps;
private double eps2; public boolean satisfyKkt(int id)
{
double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x)
if(Math.abs(this.alpha.get(id))<=this.eps)
{
if(ypgx-1<-this.tol) return false;
}
else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps)
{
if(ypgx-1>this.tol) return false;
}
else {
if(Math.abs(ypgx-1)>this.tol) return false;
}
return true;
} public void updateE() { for(int i=0;i<this.N;i++)
{
double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i);
this.E.set(i, Ei);
}
} public double kernelLinear(List<Double> X,List<Double> Y) {
//linear kernel function
int len=Y.size();
double s=0;
for(int i=0;i<len;i++)
s+=X.get(i)*Y.get(i);
return s;
} public double kernelRBF(List<Double> X,List<Double> Y)
{
//gauss kernel function int len=Y.size();
double s=0;
for(int i=0;i<len;i++)
s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i));
s=Math.exp(-s/(2*Math.pow(this.sigma, 2)));
return s;
} public double getGx(List<Double> X)
{
//calculate wx+b value
double s=0;
for(int i=0;i<this.N;i++)
{
//for debug
double debug1=kernelRBF(X, this.trainData.get(i));
double debug2=this.alpha.get(i); s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i));
}
s+=this.b;
return s;
} public int update(int x1,int x2)
{
double low=0;
double high=0;
if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2))
{
low=Math.max(0, this.alpha.get(x1)+this.alpha.get(x2)-this.C);
high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1));
}
else
{
low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1));
high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C);
}
double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta;
double newAlpha1=0; if(newAlpha2>high) newAlpha2=high;
else if(newAlpha2<low) newAlpha2=low;
newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2); if(Math.abs(newAlpha1)<=this.eps)
newAlpha1=0;
if(Math.abs(newAlpha2)<=this.eps)
newAlpha2=0;
if(Math.abs(newAlpha1-this.C)<=this.eps)
newAlpha1=this.C;
if(Math.abs(newAlpha2-this.C)<=this.eps)
newAlpha2=this.C;
if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2)
return 0;
if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2)
return 0; double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b;
double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b; if(newAlpha1>0&&newAlpha1<this.C)
this.b=b1;
else if(newAlpha2>0&&newAlpha2<this.C)
this.b=b2;
else
this.b=(b1+b2)/2; this.alpha.set(x1,newAlpha1);
this.alpha.set(x2,newAlpha2);
updateE();
return 1;
}
public int selectAlpha2(int x1) { int x2=-1;
double maxDiff=-1;
//first select x2 from 0<a<c to max(E(x1)-E(x2)) for(int i=0;i<this.N;++i)
{
if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue;
double diff=Math.abs(this.E.get(x1)-this.E.get(i));
if(diff>maxDiff)
{
maxDiff=diff;
x2=i;
}
} //second calculate eta (eta!=0)
if(x2!=-1)
{
this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2));
if(eta!=0) return x2;
} //third if cannot find in the whole train set
for(int i=0;i<this.N;i++)
{
if(i==x1) continue;
this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i));
if(Math.abs(this.eta)>this.eps) return i;
}
return -1; } public void SMO() {
//to solve alpha
int numChanged=0;
int cnt=0;
while(true)
{
cnt++;
System.out.println(cnt); numChanged=0;
for(int x1=0;x1<this.N;++x1)
{
if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue;
if(!satisfyKkt(x1))
{
int x2=selectAlpha2(x1);
if(x2==-1) continue;
numChanged+=update(x1, x2);
}
}
if(numChanged==0)
{
for(int x1=0;x1<this.N;++x1)
{
if(!satisfyKkt(x1))
{
int x2=selectAlpha2(x1);
if(x2==-1) continue;
update(x1, x2);
numChanged++;
}
}
}
if(numChanged==0)
break;
}
} public SVM() {
//load train data Data data=new Data();
Map<List<Double>, Integer> Datas=data.getTrainData();
int totalData=Datas.size();
this.trainData=new ArrayList<ArrayList<Double>>();
this.labelTrainData=new ArrayList<Integer>();
this.alpha=new ArrayList<Double>();
this.E=new ArrayList<Double>(); int i=0;
for(Map.Entry<List<Double>, Integer> entry: Datas.entrySet())
{
this.trainData.add((ArrayList<Double>) entry.getKey());
this.labelTrainData.add(entry.getValue());
this.alpha.add(0.0);
this.E.add(0.0-this.labelTrainData.get(i));
i++;
}
this.N=this.labelTrainData.size();
this.dim=this.trainData.get(0).size(); this.sigma=12;//sigma=1
this.C=0.5;//c=6
this.b=0.0;
this.tol=0.001;
this.eta=0;
this.eps=0.0000001;
this.eps2=0.00001;
} public double getB() {
//get b value
return this.b;
}
public double[] getLinearW() {
double []w=new double[this.N];
for(int i=0;i<this.N;i++)
{
for(int j=0;j<this.dim;j++)
{
w[j]+=this.alpha.get(i)*this.labelTrainData.get(i)*this.trainData.get(i).get(j);
}
}
return w;
} public int predict(List<Double> x)
{
int ans=1;
double sum=0;
for(int i=0;i<this.N;i++)
{
sum+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(x, this.trainData.get(i));
}
sum+=b;
if(sum>0)
ans=1;
else
ans=-1; return ans;
}
public static void main(String[] args) throws FileNotFoundException { SVM s=new SVM();
s.SMO();
PrintWriter out=new PrintWriter("g://download//resultpoints.txt");
for(int i=0;i<s.N;i++)
{
out.write((s.trainData.get(i).get(0)).toString());
out.write("\t");
out.write((s.trainData.get(i).get(1)).toString());
out.write("\t");
out.write(Integer.toString(s.predict(s.trainData.get(i))));
out.write("\n");
}
out.close();
//if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction
double w[]=s.getLinearW();
System.out.println(w[0]+" "+w[1]+" "+s.b+"======");
} }
用线性核函数实现的SVM得到的分类结果
画图,是用python代码
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Plot the classified points written by the SVM run, plus the separating line
# of the linear-kernel model.
# NOTE(review): the Java program writes "resultpoints.txt" while this script
# reads "myresult.txt" - confirm which file is intended.
with open("g://download/myresult.txt") as f1:
    data = f1.readlines()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)

type1_x = []
type1_y = []
type2_x = []
type2_y = []
for line in data:
    line = line.strip()
    if not line:
        # A blank (e.g. trailing) line carries no point.
        continue
    x = line.split('\t')
    x1 = float(x[0])
    x2 = float(x[1])
    x3 = int(x[2])
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# w and b printed by the linear-kernel SVM run:
# 0.8148005405344305 -0.27263471796762484 -3.8392586254518437
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437

# Decision boundary W1*x + W2*y + B = 0, solved for y.
x = np.linspace(-4, 10, 200)
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')
# type1 holds the points labelled 1, type2 everything else (-1 in the SVM output).
axes.legend((type1, type2), ('+1', '-1'), loc=1)
plt.show()
用高斯核,当C=6,sigma=1时候
高斯核,当c=0.5,sigma=1时候
当C=0.5,sigma=12时候
这说明C和sigma的取值对高斯核SVM的分类结果影响很大
sigma是高斯核函数的参数
自己实现的SVM源码的更多相关文章
- EasyPR源码剖析(1):概述
EasyPR(Easy to do Plate Recognition)是本人在opencv学习过程中接触的一个开源的中文车牌识别系统,项目Git地址为https://github.com/liuru ...
- Mahout源码目录说明&&算法集
Mahout源码目录说明 mahout项目是由多个子项目组成的,各子项目分别位于源码的不同目录下,下面对mahout的组成进行介绍: 1.mahout-core:核心程序模块,位于/core目录下: ...
- 近200篇机器学习&深度学习资料分享(含各种文档,视频,源码等)(1)
原文:http://developer.51cto.com/art/201501/464174.htm 编者按:本文收集了百来篇关于机器学习和深度学习的资料,含各种文档,视频,源码等.而且原文也会不定 ...
- Ubentu编译Android源码(AOSP)
前言: 一直想要编译一下Android 源码,之前去google 看,下载要下载repo. 当时很懵逼,repo 是个什么?(repo 是一个python 脚本,因为Android 源码git 仓库太 ...
- Android FrameWork 学习之Android 系统源码调试
这是很久以前访问掘金的时候 无意间看到的一个关于Android的文章,作者更细心,分阶段的将学习步骤记录在自己博客中,我觉得很有用,想作为分享同时也是留下自己知识的一些欠缺收藏起来,今后做项目的时候会 ...
- Python的开源人脸识别库:离线识别率高达99.38%(附源码)
Python的开源人脸识别库:离线识别率高达99.38%(附源码) 转https://cloud.tencent.com/developer/article/1359073 11.11 智慧上云 ...
- GWO(灰狼优化)算法MATLAB源码逐行中文注解(转载)
以优化SVM算法的参数c和g为例,对GWO算法MATLAB源码进行了逐行中文注解. tic % 计时器 %% 清空环境变量 close all clear clc format compact %% ...
- FaceNet pre-trained模型以及FaceNet源码使用方法和讲解
Pre-trained models Model name LFW accuracy Training dataset Architecture 20180408-102900 0.9905 CASI ...
- KNN算法介绍及源码实现
一.KNN算法介绍 邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一.所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它 ...
随机推荐
- 深入理解JVM一JVM内存模型
前言 JVM一直是java知识里面进阶阶段的重要部分,如果希望在java领域研究的更深入,则JVM则是如论如何也避开不了的话题,本系列试图通过简洁易读的方式,讲解JVM必要的知识点. 一.运行流程 我 ...
- 【BZOJ4200】【NOI2015】小园丁与老司机(动态规划,网络流)
[BZOJ4200][NOI2015]小园丁与老司机(动态规划,网络流) 题面 BZOJ权限题,洛谷链接 题解 一道二合一的题目 考虑第一问. 先考虑如何计算六个方向上的第一个点. 左右上很好考虑,只 ...
- php-fpm: hundreds of seconds in the log
favoriteI have nginx+php-fpm web serverSo I've noticed in php5-fpm.log many strange lines:[03-Sep-20 ...
- 【CF113D】Museum
Portal --> cf113D Solution 额题意的话大概就是给一个无向图然后两个人给两个出发点,每个点每分钟有\(p[i]\)的概率停留,问这两个人在每个点相遇的概率是多少 如果说我 ...
- docker-compose写法收集
version: '3.3' services: php: image: docker.ksyun.com/php7.:latest volumes: - ./env/log/apps:/data/l ...
- IO流-文件拷贝
其实文件的拷贝还是文件读取写入的应用,实际是读取此路径上的文件,然后写入到指定路径下的文件. 代码举例: import java.io.*; import java.lang.*; class Tes ...
- R0—New packages for reading data into R — fast
小伙伴儿们有福啦,2015年4月10日,Hadley Wickham大牛(开发了著名的ggplots包和plyr包等)和RStudio小组又出新作啦,新作品readr包和readxl包分别用于R读取t ...
- 【CodeForces】866D. Buy Low Sell High
[题意]已知n天股价,每天可以买入一股或卖出一股或不作为,最后必须持0股,求最大收益. [算法]堆 贪心? [题解] 不作为思想:[不作为=买入再卖出] 根据不作为思想,可以推出中转站思想. 中转站思 ...
- 【CodeForces】576 B. Invariance of Tree
[题目]B. Invariance of Tree [题意]给定n个数的置换,要求使n个点连成1棵树,满足u,v有边当且仅当a[u],a[v]有边,求一种方案或无解.n<=10^5. [算法]数 ...
- el-option > 1500 条时的卡顿问题
本文地址: http://www.cnblogs.com/veinyin/p/8473938.html 在做项目时遇到的一个问题. 项目是基于 Vue 框架做的. select 的 option 是 ...