LibSVM C/C++
本系列文章由 @YhL_Leo 出品,转载请注明出处。
文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779
在LibSVM
的库的svm.h
头文件中定义了四个主要结构体:
1 训练模型的结构体
struct svm_problem
{
int l; // total number of samples
double *y; // label of each sample
struct svm_node **x; // feature vector of each sample
};
样本的类别通常使用+1
与-1
进行标识。如果样本的类别,则分类的准确率也就无法计算。
2 数据节点的结构体
struct svm_node
{
int index;
double value;
};
数据组织结构如图1所示:
3 模型参数结构体
struct svm_parameter
{
int svm_type;
int kernel_type;
int degree; /* for poly */
double gamma; /* for poly/rbf/sigmoid */
double coef0; /* for poly/sigmoid */
/* these are for training only */
double cache_size; /* in MB */
double eps; /* stopping criteria */
double C; /* for C_SVC, EPSILON_SVR and NU_SVR */
int nr_weight; /* for C_SVC */
int *weight_label; /* for C_SVC */
double* weight; /* for C_SVC */
double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */
double p; /* for EPSILON_SVR */
int shrinking; /* use the shrinking heuristics */
int probability; /* do probability estimates */
};
其中,各个参数的含义为:
-s svm_type : set type of SVM (default 0)
0 -- C-SVC
1 -- nu-SVC
2 -- one-class SVM
3 -- epsilon-SVR
4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
0 -- linear: u'*v
1 -- polynomial: (gamma*u'*v + coef0)^degree
2 -- radial basis function: exp(-gamma*|u-v|^2)
3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)
SVM模型类型和核函数类型:
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */
4 训练输出模型结构体
struct svm_model
{
struct svm_parameter param; /* parameter */
int nr_class; /* number of classes, = 2 in regression/one class svm */
int l; /* total #SV */
struct svm_node **SV; /* SVs (SV[l]) */
double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
double *rho; /* constants in decision functions (rho[k*(k-1)/2]) */
double *probA; /* pariwise probability information */
double *probB;
int *sv_indices; /* sv_indices[0,...,nSV-1] are values in [1,...,num_traning_data] to indicate SVs in the training set */
/* for classification only */
int *label; /* label of each class (label[k]) */
int *nSV; /* number of SVs for each class (nSV[k]) */
/* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
/* XXX */
int free_sv; /* 1 if svm_model is created by svm_load_model*/
/* 0 if svm_model is created by svm_train */
};
5 使用方法
以LibSVM
提供的样本特征集heart_scale
为例,首先需要读取样本特征数据,可以利用svm-train.c
文件中的read_problem
函数,为了方便使用,对其进行了重写改写:
// TrainingDataLoad.h
/*
Load training data from svm format file.
- Editor: Yahui Liu.
- Data: 2015-11-30
- Email: yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef TRAINING_DATA_LOAD_H
#define TRAINING_DATA_LOAD_H
#pragma once
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <errno.h>
#include "svm.h"
//#include "svm-scale.c"
using namespace std;
#define MAX_LINE_LEN 1024
class TrainingDateLoad
{
public:
TrainingDateLoad()
{
line = NULL;
}
~TrainingDateLoad()
{
line = NULL;
}
public:
char* line;
// public:
// static struct svm_parameter _paramInit;
public:
/*! load svm model */
void loadModel( std::string filename, struct svm_model*& model);
/*! skip the target */
void svmSkipTarget( char*& p);
/* skip the element */
void svmSkipElement( char*& p);
void initialParams( struct svm_parameter& param );
/*! load training data */
void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );
char* readline(FILE *input);
void exit_input_error(int line_num)
{
cout << "Wrong input format at line: " << line_num << endl;
exit(1);
}
};
#endif // TRAINING_DATA_LOAD_H
// TrainingDataLoad.cpp
#include "TrainingDataLoad.h"
void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
{
model = svm_load_model(filename.c_str());
}
void TrainingDateLoad::svmSkipTarget(char*& p)
{
while(isspace(*p)) ++p;
while(!isspace(*p)) ++p;
}
void TrainingDateLoad::svmSkipElement(char*& p)
{
while(*p!=':') ++p;
++p;
while(isspace(*p)) ++p;
while(*p && !isspace(*p)) ++p;
}
void TrainingDateLoad::initialParams( struct svm_parameter& param )
{
// default values
param.svm_type = C_SVC;
param.kernel_type = RBF;
param.degree = 3;
param.gamma = 0; // 1/num_features
param.coef0 = 0;
param.nu = 0.5;
param.cache_size = 100;
param.C = 1;
param.eps = 1e-3;
param.p = 0.1;
param.shrinking = 1;
param.probability = 0;
param.nr_weight = 0;
param.weight_label = NULL;
param.weight = NULL;
}
void TrainingDateLoad::readProblem( std::string filename,
struct svm_problem& prob, struct svm_parameter& param )
{
int max_index, inst_max_index, i;
size_t elements, j;
FILE *fp = fopen(filename.c_str(),"r");
char *endptr;
char *idx, *val, *label;
if(fp == NULL)
{
fprintf(stderr,"can't open input file %s\n",filename);
exit(1);
}
prob.l = 0;
elements = 0;
line = new char[MAX_LINE_LEN];
while(readline(fp)!=NULL)
{
char *p = strtok(line," \t"); // label
// features
while(1)
{
p = strtok(NULL," \t");
if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
break;
++elements;
}
++elements;
++prob.l;
}
rewind(fp);
prob.y = new double[prob.l];
prob.x = new struct svm_node *[prob.l];
struct svm_node *x_space = new struct svm_node[elements];
max_index = 0;
j=0;
for(i=0;i<prob.l;i++)
{
inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
readline(fp);
prob.x[i] = &x_space[j];
label = strtok(line," \t\n");
if(label == NULL) // empty line
exit_input_error(i+1);
prob.y[i] = strtod(label,&endptr);
if(endptr == label || *endptr != '\0')
exit_input_error(i+1);
while(1)
{
idx = strtok(NULL,":");
val = strtok(NULL," \t");
if(val == NULL)
break;
errno = 0;
x_space[j].index = (int) strtol(idx,&endptr,10);
if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
exit_input_error(i+1);
else
inst_max_index = x_space[j].index;
errno = 0;
x_space[j].value = strtod(val,&endptr);
if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
exit_input_error(i+1);
++j;
}
if(inst_max_index > max_index)
max_index = inst_max_index;
x_space[j++].index = -1;
}
if(param.gamma == 0 && max_index > 0)
param.gamma = 1.0/max_index;
if(param.kernel_type == PRECOMPUTED)
for(i=0;i<prob.l;i++)
{
if (prob.x[i][0].index != 0)
{
fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
exit(1);
}
if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
{
fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
exit(1);
}
}
fclose(fp);
}
char* TrainingDateLoad::readline(FILE *input)
{
int len;
if(fgets(line,MAX_LINE_LEN,input) == NULL)
return NULL;
int max_line_len = MAX_LINE_LEN;
while(strrchr(line,'\n') == NULL)
{
max_line_len *= 2;
line = (char *) realloc(line,max_line_len);
len = (int) strlen(line);
if(fgets(line+len,max_line_len-len,input) == NULL)
break;
}
return line;
}
将样本训练与预测进行改写:
// LibSVMTools.h
/*
LibSVM train and predict tools.
- Editor: Yahui Liu.
- Data: 2015-12-3
- Email: yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef LIBSVM_TOOL_H
#define LIBSVM_TOOL_H
#pragma once
#include <iostream>
#include <string>
#include "svm.h"
#include "TrainingDataLoad.h"
class LibSVMTools
{
public:
LibSVMTools(){}
~LibSVMTools(){}
public:
/*!
- featureFile: features of images saved in libsvm format.
- saveModelFile: save the trained model file.
**/
void libSvmTrain(std::string featureFile, std::string saveModelFile);
/*!
- featureFile: features of images saved in libsvm format.
- modelFile: libsvm trained model.
- savePredictFile: save the predicting results.
**/
void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
};
#endif // LIBSVM_TOOL_H
// LibSVMTools.cpp
#include "LibSVMTools.h"
void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
{
struct svm_parameter param;
struct svm_problem prob;
TrainingDateLoad* trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);
const char*errorMsg = svm_check_parameter(&prob, ¶m);
if ( errorMsg )
{
cout << errorMsg << endl;
return;
}
struct svm_model *model = svm_train(&prob, ¶m);
#if 1
cout << "svm_type: " << model->param.svm_type << endl <<
"kernel_type: " << model->param.kernel_type << endl <<
"gamma: " << model->param.gamma << endl <<
"nr_class: " << model->nr_class << endl <<
"total_sv: " << model->l << endl <<
"rho: " << model->rho[0] << endl <<
"label: " << model->label[0] << " " << model->label[1] << endl <<
"nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;
#endif
int saveModel = svm_save_model( saveModelFile.c_str(), model );
}
void LibSVMTools::libSvmPredict(std::string featureFile,
std::string modelFile, std::string savePredictFile)
{
struct svm_parameter param;
struct svm_problem prob;
TrainingDateLoad * trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);
struct svm_model* model;
trainData->loadModel(modelFile.c_str(), model);
float correct(0.0); // all correct
float uncorrect_1(0.0); // pos to neg
float uncorrect_2(0.0); // neg to pos
if ( prob.l )
{
const int nCount = prob.l;;
ofstream outfile( savePredictFile, ios::out );
for( int i=0; i<nCount; i++ )
{
double label = svm_predict(model, prob.x[i]);
if ( label == prob.y[i] )
{
correct ++;
}
else if ( label == -1.0 )
{
uncorrect_1 ++;
}
else
{
uncorrect_2 ++;
}
outfile << label << endl;
}
#if 1
cout << "total data count: " << nCount << endl <<
"classification correct: " << correct << endl <<
"pos to neg count: " << uncorrect_1 << endl <<
"neg to pos count: " << uncorrect_2 << endl;
cout << "Accuracy: " << static_cast<float>(correct/nCount)
<< "(" << correct << "/" << nCount << ")" << endl;
#endif
outfile.close();
}
}
用例Demo:
// train
#include "LibSVMTools.h"
void main()
{
std::cout <<
"************************************************************" << endl <<
"** PROGRAM: LibSVM model training. **" << endl <<
"** **" << endl <<
"** Author: Yahui Liu. **" << endl <<
"** School of Remote Sensing & Inf. Eng. **" << endl <<
"** Wuhan University, Hubei, P.R. China **" << endl <<
"** Email: yahui.cvrs@gmail.com **" << endl <<
"** Create time: Dec. 1, 2015 **" << endl <<
"************************************************************" << endl;
string filename = "..\\..\\..\\Data\\heat_scale";
std::string savefielname = "..\\..\\..\\Data\\train.model";
LibSVMTools* libsvm = new LibSVMTools();
libsvm->libSvmTrain(filename, savefielname);
delete libsvm;
}
/*------------------------------------------------------------------------------------*/
// predict
#include "LibSVMTools.h"
void main()
{
std::cout <<
"************************************************************" << endl <<
"** PROGRAM: LibSVM predict. **" << endl <<
"** **" << endl <<
"** Author: Yahui Liu. **" << endl <<
"** School of Remote Sensing & Inf. Eng. **" << endl <<
"** Wuhan University, Hubei, P.R. China **" << endl <<
"** Email: yahui.cvrs@gmail.com **" << endl <<
"** Create time: Dec. 1, 2015 **" << endl <<
"************************************************************" << endl;
std::string featureFile = "..\\..\\..\\Data\\heart_scale";
std::string modelFile = "..\\..\\..\\Data\\train.model";
std::string savePredictFile = "..\\..\\..\\Data\\predict.out";
LibSVMTools* libsvm = new LibSVMTools();
libsvm->libSvmPredict(featureFile, modelFile, savePredictFile);
delete libsvm;
}
LibSVM C/C++的更多相关文章
- 6.LibSVM核函数
libsvm的核函数类型(svmtrain.c注释部分): "-t kernel_type : set type of kernel function (default 2)\n" ...
- libsvm的数据格式及制作
1.libsvm数据格式 libsvm使用的训练数据和检验数据文件格式如下: [label] [index1]:[value1] [index2]:[value2] … [label] [index1 ...
- libsvm下的windows版本中的工具的使用
下载的libsvm包里面已经为我们编译好了(windows).进入libsvm\windows,可以看到这几个exe文件: a.svm-toy.exe:图形界面,可以自己画点,产生数据等. b.svm ...
- 【转】Windows下使用libsvm中的grid.py和easy.py进行参数调优
libsvm中有进行参数调优的工具grid.py和easy.py可以使用,这些工具可以帮助我们选择更好的参数,减少自己参数选优带来的烦扰. 所需工具:libsvm.gnuplot 本机环境:Windo ...
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- LibSVM for Python 使用
经历手写SVM的惨烈教训(还是太年轻)之后,我决定使用工具箱/第三方库 Python libsvm的GitHub仓库 LibSVM是开源的SVM实现,支持C, C++, Java,Python , R ...
- libsvm简介和函数调用参数说明
1. libSVM简介 libSVM是台湾林智仁(Chih-Jen Lin) 教授2001年开发的一套支持向量机库,这套库运算速度挺快,可以很方便的对数据做分类或回归.由于libSVM程序小 ...
- Libsvm Matlab 快速安装教程 (适用于Win7+, 64bit, and Matlab2016a+)
近日在开始学习Machine Learning SVM 相关算法,将Matlab平台安装SVM的步骤记录如下,亲测可用: 开发环境: Windows 8 64 bit, Matlab 2016a, S ...
- win7下matlab2016a配置libsvm
1.下载libsvm https://www.csie.ntu.edu.tw/~cjlin/libsvm/ 2.解压到matlab2016a的安装目录的toolbox下 例如我的D:\Program ...
- WEKA运行LIBSVM出现problem evaluating classifier:rand
原来这个实验已经做了的.也出现了些问题,但是上网找到了解决方法,那个时候是完成数据挖掘的课程论文,用WEKA运行LIBSVM,也没有很深入,简单跑出结果就算了. 这次想着研讨会就讲这个,想着深入进去, ...
随机推荐
- [剑指offer] 5. 用两个栈实现队列+[剑指offer]30. 包含min函数的栈(等同于leetcode155) +[剑指offer]31.栈的压入、弹出序列 (队列 栈)
c++里面stack,queue的pop都是没有返回值的, vector的pop_back()也没有返回值. 思路: 队列是先进先出 , 在stack2里逆序放置stack1的元素,然后stack2. ...
- Django REST Framework - 分页 - 渲染器 - 解析器
为什么要使用分页? 我们数据表中可能会有成千上万条数据,当我们访问某张表的所有数据时,我们不太可能需要一次把所有的数据都展示出来,因为数据量很大,对服务端的内存压力比较大还有就是网络传输过程中耗时也会 ...
- 题解 CF1051F 【The Shortest Statement】
这道题思路比较有意思,第一次做完全没想到点子上... 看到题目第一反应是一道最短路裸题,但是数据范围1e5说明完全不可能. 这个时候可以观察到题目给出了一个很有意思的条件,就是说边最多比点多20. 这 ...
- js手动定时清除localStorage
<script type="text/javascript"> // 假设要保存变量 a 的值,过期时间为 3600秒 // 保存值 var obj = new Obj ...
- fedora linux源代码下载
yumdownloader --source kernel 如果是下载insight 就是 yumdownloader --source insight 下载到的是当前目录. 然后在用rpm2cpio ...
- 【POJ 2485】 Highways
[POJ 2485] Highways 最小生成树模板 Prim #include using namespace std; int mp[501][501]; int dis[501]; bool ...
- 阿里云server部署架构
近期要上马一个项目,客户要求所有部署到阿里云的server,做了一个阿里云的部署方案. 上图: watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvc21hbGx ...
- How to start/stop DB instance of Oracle under Linux
All below actions should be executed with "oracle" user account 1. Check the status of lis ...
- iOS开发实践之xib载入注意问题
xib都会addSubview加入到控制器view中时程序崩溃.错误提示: 'NSInvalidArgumentException', reason: '-[ UITapGestureRecogniz ...
- m_Orchestrate learning system---二十七、修改时如何快速找到作用位置
m_Orchestrate learning system---二十七.修改时如何快速找到作用位置 一.总结 一句话总结:找人,找起作用的位置真的重要,找到就事半功倍了 加载页面的时候观察在f12的e ...