前言

项目有一个模块需要将不同类别的图片进行分类,共有三个类别,使用SVM实现分类。

实现步骤:

1.创建训练样本库;

2.训练、测试SVM模型;

3.SVM的数据要求;

实现系统:

windows_x64、opencv2.4.10、 VS2013

实现过程:

1.创建训练样本库;

1)将图片以包含类别的名称进行命名,比如0(1).jpg等等;

2)将所有已命名正确的训练样本保存在同一个文件夹中;

3)在训练样本库的文件夹目录下创建python源文件;

python代码:

  1. import sys
  2. import os
  3. import string
  4. import re
  5.  
  6. if __name__=='__main__':
  7. print('Begin generate path and label.')
  8. path_file=open('train_path.txt','w')
  9. path='E:/carriage_recognition/redplate_detection/svm_train_test/data/train/model'
  10. pic_type='.png'
  11. pat=re.compile(r'^(\d+)')
  12. files=os.listdir(path)
  13. files_tmp=[]
  14. for i in files:
  15. if pic_type in i and not os.path.isdir(path+'/'+i):
  16. files_tmp.append(i)
  17. files=files_tmp
  18. for file in range(len(files)):
  19. ret=pat.match(files[file])
  20. path_file.write(path+'/'+files[file]+'\n')
  21. if file<len(files)-1:
  22. path_file.write(ret.group(1)+'\n')
  23. else:
  24. path_file.write(ret.group(1))
  25. path_file.close()
  26. print('finish......')

4)运行代码,即可,生成包含图片名称和类别的文本文件,用于SVM训练过程中读物图片获取相应的类别标签;

  1. E:/carriage_recognition/redplate_detection/svm_train_test/data/train/model/0 (1).png
  2. 0
  3. E:/carriage_recognition/redplate_detection/svm_train_test/data/train/model/0 (10).png
  4. 0

奇数行表示训练样本图片的路径名称;偶数行表示该图片的类别标签;

2.训练、测试SVM模型;

1)image.h,主要实现过程的代码;

  1. #include <fstream>
  2. #include <vector>
  3. #include<direct.h>
  4. #include <opencv2\core\core.hpp> //红牌事件检测头文件
  5. #include <opencv2\opencv.hpp>
  6.  
  7. using namespace std;
  8. using namespace cv;
  9.  
  10. #define ON_STUDY 0
  11. #define Num 3 //类别数目
  12. #define STANDARD_ROW 65
  13. #define STANDARD_COL 85
  14.  
  15. #define STANDARD_ROW_CHOOSE 65
  16. #define STANDARD_COL_CHOOSE 85
  17.  
  18. #define CHANELS 1
  19. class NumTrainData
  20. {
  21. public:
  22. NumTrainData()
  23. {
  24. memset(data, 0, sizeof(data));
  25. result = -1;
  26. }
  27. public:
  28. float data[CHANELS*STANDARD_COL_CHOOSE*STANDARD_ROW_CHOOSE];
  29. int result;
  30. };
  31.  
  32. vector<string> img_path;//输入文件名变量
  33. vector<string> img_test_path;//输入文件名变量
  34. vector<int> img_catg;
  35. vector<int> img_test_catg;
  36. int nLine = 0;
  37. string buf;
  38.  
  39. unsigned long n;
  40. vector<NumTrainData> buffer;
  41. int featureLen = CHANELS*STANDARD_COL_CHOOSE*STANDARD_ROW_CHOOSE;
  42.  
  43. char* test_path = "./test_path.txt";
  44. char* train_path = "./train_path.txt";
  45. //存放输出结果
  46. char* save_path = "./SVM_DATA_train_0.5_0.2.xml";
  47. ofstream matrix_config("./fusion_matrix_0.5_0.2.txt"); //存放混淆矩阵
  48. string save_wrong_results = "./wrong_0.5_0.2"; //存放识别错误的结果
  49.  
  50. void ReadTrainData()
  51. {
  52. ifstream svm_data(train_path);//训练样本图片的路径都写在这个txt文件中,使用python可以得到这个txt文件
  53. while (svm_data)//将训练样本文件依次读取进来
  54. {
  55. if (getline(svm_data, buf))
  56. {
  57. nLine++;
  58. if (nLine % 2 == 0)//注:奇数行是图片全路径,偶数行是标签
  59. {
  60. img_catg.push_back(atoi(buf.c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错
  61. }
  62. else
  63. {
  64. img_path.push_back(buf);//图像路径
  65. }
  66. }
  67. }
  68. svm_data.close();//关闭文件
  69. }
  70.  
  71. void ReadTestData()
  72. {
  73. ifstream svm_data(test_path);//训练样本图片的路径都写在这个txt文件中,使用python可以得到这个txt文件
  74. while (svm_data)//将训练样本文件依次读取进来
  75. {
  76. if (getline(svm_data, buf))
  77. {
  78. nLine++;
  79. if (nLine % 2 == 0)//注:奇数行是图片全路径,偶数行是标签
  80. {
  81. img_test_catg.push_back(atoi(buf.c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错
  82. }
  83. else
  84. {
  85. img_test_path.push_back(buf);//图像路径
  86. }
  87. }
  88. }
  89. svm_data.close();//关闭文件
  90. }
  91.  
  92. void LoadTrainData()
  93. {
  94. Mat src; //= Mat::zeros(rows, cols, CV_8UC1);
  95. Mat dst;
  96. NumTrainData rtd;
  97. cout << "Begin load training data...." << endl;
  98. for (int i = 0; i < img_path.size(); i++)
  99. {
  100. rtd.result = img_catg[i];
  101.  
  102. int k = 0;
  103. if (CHANELS == 1) // gray image
  104. {
  105. src = imread(img_path[i].c_str(), 0);
  106. dst = src;
  107.  
  108. Mat temp = Mat::zeros(STANDARD_ROW, STANDARD_COL, CV_8UC1);
  109.  
  110. //尺寸归一化
  111. resize(dst, temp, temp.size());
  112.  
  113. float m[CHANELS*STANDARD_COL_CHOOSE*STANDARD_ROW_CHOOSE];
  114. for (int i = 0; i<STANDARD_ROW; i++)
  115. {
  116. for (int j = 0; j<STANDARD_COL; j++)
  117. {
  118. rtd.data[i * STANDARD_COL + j] = temp.at<uchar>(i, j);
  119. }
  120. }
  121. }
  122. else if (CHANELS == 3) // 3-channel image
  123. {
  124. src = imread(img_path[i].c_str(), 1);
  125. dst = src;
  126.  
  127. Mat temp = Mat::zeros(STANDARD_ROW, STANDARD_COL, CV_8UC1); //大小归一化
  128. resize(dst, temp, temp.size());
  129. //cout << temp.channels() << endl;
  130.  
  131. for (int i = 0; i < STANDARD_ROW_CHOOSE; i++)
  132. {
  133. for (int j = 0; j < STANDARD_COL_CHOOSE; j++)
  134. {
  135. Vec3b& mp = temp.at<Vec3b>(i, j);
  136. float B = mp.val[0];
  137. //cout << "B=" << B << endl;
  138. float G = mp.val[1];
  139. //cout << "G=" << B << endl;
  140. float R = mp.val[2];
  141. //cout << "R=" << B << endl;
  142.  
  143. rtd.data[k++] = B; //R
  144. rtd.data[k++] = G; //G
  145. rtd.data[k++] = R; //B
  146. }
  147. }
  148. }
  149. buffer.push_back(rtd);
  150. //cout << i << "th Image is loaded!" << endl;
  151. }
  152. cout << "Loading image finished!" << endl;
  153. }
  154.  
  155. void SVMPredict()
  156. {
  157. int x = 0;
  158. //_mkdir(save_test_preprocess.c_str());
  159. _mkdir(save_wrong_results.c_str());
  160. int fusion_matrix[Num][Num] = { 0 };
  161.  
  162. CvSVM svm;
  163. svm.load(save_path);
  164. Mat src,dst;
  165. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
  166.  
  167. NumTrainData rtd;
  168. int label = -1;
  169. int right = 0, error = 0;
  170. save_wrong_results += "/%d_true_%d_false_%d.png";
  171. //
  172. double ptrue_rtrue = 0;
  173. double ptrue = 0;
  174. double rtrue = 0;
  175. //
  176. for (int i = 0; i < img_test_path.size(); i++)
  177. {
  178. label = img_test_catg[i];
  179. rtd.result = label;
  180.  
  181. if (CHANELS == 1)
  182. {
  183. src = imread(img_test_path[i].c_str(), 0);
  184. dst = src;
  185.  
  186. Mat temp = Mat::zeros(STANDARD_ROW, STANDARD_COL, CV_8UC1); //大小归一化
  187. resize(dst, temp, temp.size());
  188.  
  189. for (int i = 0; i<STANDARD_ROW; i++)
  190. {
  191. for (int j = 0; j<STANDARD_COL; j++)
  192. {
  193. m.at<float>(0, j + i * STANDARD_COL) = temp.at<uchar>(i, j);
  194. }
  195. }
  196. normalize(m, m);
  197. }
  198. else if (CHANELS == 3) // 3-channel image
  199. {
  200. src = imread(img_test_path[i].c_str(), 1);
  201. dst = src;
  202.  
  203. Mat temp = Mat::zeros(STANDARD_ROW, STANDARD_COL, CV_8UC1); //大小归一化
  204. resize(dst, temp, temp.size());
  205.  
  206. int k = 0;
  207. for (int i = 0; i < STANDARD_ROW_CHOOSE; i++)
  208. {
  209. for (int j = 0; j < STANDARD_COL_CHOOSE; j++)
  210. {
  211. Vec3b& mp = temp.at<Vec3b>(i, j);
  212. float B = mp.val[0];
  213. float G = mp.val[1];
  214. float R = mp.val[2];
  215.  
  216. m.at<float>(0, k++) = B; //R
  217. m.at<float>(0, k++) = G; //G
  218. m.at<float>(0, k++) = R; //B
  219. }
  220. }
  221. }
  222.  
  223. int ret = svm.predict(m);
  224. //if (ret == 3)
  225. // ret = 1;
  226. cout << "Picture->" << img_test_path[i].c_str() << " : \nTrue label is [" << label << "] Predicted label is [" << ret << "]" << endl;
  227. //
  228. //计算FSCORE指标各个参数
  229. if (label == 0 && ret == 0) ptrue_rtrue++;//识别为红牌且实际为红牌;
  230. if (ret == 0) ptrue++;//识别为红牌的个数
  231. if (label == 0) rtrue++;//实际为红牌的个数
  232. //
  233. //存储错误图片
  234. if (label != ret)
  235. {
  236. x++;
  237. char filename[200];
  238. src = imread(img_test_path[i].c_str(), 1);
  239. sprintf(filename, save_wrong_results.c_str(), x, label, ret);
  240. imwrite(filename, src);
  241. }
  242. //计算混淆矩阵
  243. //fusion_matrix[label][ret] = fusion_matrix[label][ret] + 1;
  244. }
  245. //
  246. //FSCORE
  247. std::cout << "count_all: " << img_test_path.size() << std::endl;
  248. std::cout << "ptrue_rtrue: " << ptrue_rtrue << std::endl;
  249. std::cout << "ptrue: " << ptrue << std::endl;
  250. std::cout << "rtrue: " << rtrue << std::endl;
  251. //precise
  252. double precise = 0;
  253. if (ptrue != 0)
  254. {
  255. precise = ptrue_rtrue / ptrue;
  256. std::cout << "precise: " << precise << std::endl;
  257. }
  258. else
  259. {
  260. std::cout << "precise: " << "NA" << std::endl;
  261. }
  262. //recall
  263. double recall = 0;
  264. if (rtrue != 0)
  265. {
  266. recall = ptrue_rtrue / rtrue;
  267. std::cout << "recall: " << recall << std::endl;
  268. }
  269. else
  270. {
  271. std::cout << "recall: " << "NA" << std::endl;
  272. }
  273. //FSCORE
  274. double FScore = 0;
  275. if (precise + recall != 0)
  276. {
  277. FScore = 2 * (precise * recall) / (precise + recall);
  278. std::cout << "FScore: " << FScore << std::endl;
  279. }
  280. else
  281. {
  282. std::cout << "FScore: " << "NA" << std::endl;
  283. }
  284. //
  285. //for (size_t i = 0; i < Num; i++)
  286. //{
  287. // for (size_t j = 0; j < Num; j++)
  288. // {
  289. // matrix_config << fusion_matrix[i][j] << " ";
  290. // }
  291. // matrix_config << endl;
  292. //}
  293. //matrix_config.close();
  294. cout << "Task finished!output_matix" << endl;
  295. getchar();
  296. }
  297.  
  298. void SVMTrain(vector<NumTrainData>& trainData)
  299. {
  300. int testCount = trainData.size();
  301.  
  302. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
  303. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
  304. //Mat res = Mat::zeros(testCount, 1, CV_32SC1);
  305. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
  306.  
  307. for (int i = 0; i< testCount; i++)
  308. {
  309.  
  310. NumTrainData td = trainData.at(i);
  311. memcpy(m.data, td.data, featureLen * sizeof(float));
  312. normalize(m, m);
  313. memcpy(data.data + i*featureLen * sizeof(float), m.data, featureLen * sizeof(float));
  314. cout << td.result << endl;
  315. res.at<int>(i, 0) = td.result;
  316.  
  317. }
  318.  
  319. /////////////START SVM TRAINNING//////////////////
  320. //CvSVM svm = CvSVM();
  321. CvSVM svm;
  322. CvSVMParams param;
  323. CvTermCriteria criteria;
  324.  
  325. criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
  326. param = CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 0.5, 1.0, 0.2, 0.5, 0.1, NULL, criteria); //gamma=2;C=3
  327. cout << "Begin to train model using given train data.....\n Total training sample count is " << testCount << endl;
  328. svm.train(data, res, Mat(), Mat(), param);
  329. svm.save(save_path);
  330. cout << "Finish" << endl;
  331. }

  

2)主要函数说明;

2.1)SVMTrain函数主要实现模型的训练,其中训练参数使用RBF核,主要调整gamma和C这两个参数,固定一个参数调整另一个参数,最后确定模型参数分别为0.5/0.2;

2.2)SVMPredict函数主要实现对测试样本库的测试,并使用FScore指标测试SVM模型的性能;也可以使用混淆矩阵测试性能;

2.3)ReadTrainData/ ReadTestData函数分别用于获取训练和测试样本库图片的名称和类别标签;

2.4)LoadTrainData函数用于读取训练数据,并进行图像处理;

2.5)代码中使用整张图片的信息进行归一化之后作为特征;

3)主函数入口

  1. #include "image.h"
  2.  
  3. int main(int argc, char *argv[])
  4. {
  5. #if (ON_STUDY)
  6. ReadTrainData();
  7. LoadTrainData();
  8. SVMTrain(buffer);
  9. #else
  10. ReadTestData();
  11. SVMPredict();
  12. #endif
  13.  
  14. getchar();
  15. }

参数ON_STUDY表示选择进行训练或者测试的标志位;

3.SVM的数据要求;

需要说明的是就是SVM对于输入的数据类型是有要求的,即mTrainData(训练数据矩阵)以及mFlagPosNeg(标签矩阵)都必须为CV_32FC1类型(我的环境标签矩阵是CV_32SC1类型的),因此需要进行类型转换,而且必须保证转换完之后数值都不能大于1,这就给我们了两点启示:1)不能直接用下采样后的图像像素作为训练数据的输入,需要进行类型的归一化。2)类型转换时要使用normlize函数,保证其数值范围不大于1,而不能简单的使用Mat的成员函数coverto,只变类型不变数值范围。( 需要注意!)

问题:

该实现过程需要人工调整参数,比较繁琐,可以思考一下,是否还存在其他问题;

参考:

1.http://blog.csdn.net/firefight/article/details/6452188

2.opencv中SVM的那些事儿

 

SVM实现分类识别及参数调优(一)的更多相关文章

  1. 【机器学习基础】SVM实现分类识别及参数调优(二)

    前言 实现分类可以使用SVM方法,但是需要人工调参,具体过程请参考here,这个比较麻烦,小鹅不喜欢麻烦,正好看到SVM可以自动调优,甚好! 注意 1.reshape的使用: https://docs ...

  2. 从信用卡欺诈模型看不平衡数据分类(1)数据层面:使用过采样是主流,过采样通常使用smote,或者少数使用数据复制。过采样后模型选择RF、xgboost、神经网络能够取得非常不错的效果。(2)模型层面:使用模型集成,样本不做处理,将各个模型进行特征选择、参数调优后进行集成,通常也能够取得不错的结果。(3)其他方法:偶尔可以使用异常检测技术,IF为主

    总结:不平衡数据的分类,(1)数据层面:使用过采样是主流,过采样通常使用smote,或者少数使用数据复制.过采样后模型选择RF.xgboost.神经网络能够取得非常不错的效果.(2)模型层面:使用模型 ...

  3. Bayesian Optimization使用Hyperopt进行参数调优

    超参数优化 Bayesian Optimization使用Hyperopt进行参数调优 1. 前言 本文将介绍一种快速有效的方法用于实现机器学习模型的调参.有两种常用的调参方法:网格搜索和随机搜索.每 ...

  4. 【转】XGBoost参数调优完全指南(附Python代码)

    xgboost入门非常经典的材料,虽然读起来比较吃力,但是会有很大的帮助: 英文原文链接:https://www.analyticsvidhya.com/blog/2016/03/complete-g ...

  5. 【深度学习篇】--神经网络中的调优一,超参数调优和Early_Stopping

    一.前述 调优对于模型训练速度,准确率方面至关重要,所以本文对神经网络中的调优做一个总结. 二.神经网络超参数调优 1.适当调整隐藏层数对于许多问题,你可以开始只用一个隐藏层,就可以获得不错的结果,比 ...

  6. XGBoost参数调优完全指南

    简介 如果你的预测模型表现得有些不尽如人意,那就用XGBoost吧.XGBoost算法现在已经成为很多数据工程师的重要武器.它是一种十分精致的算法,可以处理各种不规则的数据.构造一个使用XGBoost ...

  7. xgboost 参数调优指南

    一.XGBoost的优势 XGBoost算法可以给预测模型带来能力的提升.当我对它的表现有更多了解的时候,当我对它的高准确率背后的原理有更多了解的时候,我发现它具有很多优势: 1 正则化 标准GBDT ...

  8. XGBoost模型的参数调优

    XGBoost算法在实际运行的过程中,可以通过以下要点进行参数调优: (1)添加正则项: 在模型参数中添加正则项,或加大正则项的惩罚力度,即通过调整加权参数,从而避免模型出现过拟合的情况. (2)控制 ...

  9. 评价指标的局限性、ROC曲线、余弦距离、A/B测试、模型评估的方法、超参数调优、过拟合与欠拟合

    1.评价指标的局限性 问题1 准确性的局限性 准确率是分类问题中最简单也是最直观的评价指标,但存在明显的缺陷.比如,当负样本占99%时,分类器把所有样本都预测为负样本也可以获得99%的准确率.所以,当 ...

随机推荐

  1. Linq 常用方法解释

    /// <summary> /// linq /// </summary> public class Linq { /// <summary> /// 测试 /// ...

  2. Linux访问windows共享(samba/smbclient/smbfs/cifs)

    samba是一个实现不同操作系统之间文件共享和打印机共享的一种SMB协议的免费软件.●安装samba,samba-client和cifs-utils.x86_64此步将自动安装好相关依赖包:samba ...

  3. 开关灯问题 BulbSwitch

    2018-06-17 11:54:51 开关电灯问题是一个比较经典的趣味数学题,本文中主要介绍其中的一些常见情况. 一.Bulb Switch 问题描述: 问题求解: 初始状态:off, off, o ...

  4. Liebig's Barrels CodeForces - 985C (贪心)

    链接 大意:给定$nk$块木板, 要制作$n$个$k$块板的桶, 要求任意两桶容积差不超过$l$, 每个桶的容积为最短木板长, 输出$n$个桶的最大容积和 假设最短板长$m$, 显然最后桶的体积都在$ ...

  5. CentOS7系统更换YUM Repo源

    CentOS7系统更换YUM Repo源 备份原镜像 sudo mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.re ...

  6. 数据库故障诊断(Troubleshooting)之性能问题导致的数据库严重故障案例之一

    好久不来这里写东西,今天有点时间,来这里写点最近遇到的事情.前段时间,某电信业务用户因某核心生产库最近多次宕机重启,多方人员介入无果后,给我发来了邮件,大概意思就是现在该问题已经造成了比较严重的后果, ...

  7. 用Maven创建第一个web项目

    http://www.cnblogs.com/leiOOlei/p/3361633.html 一.创建项目 1.Eclipse中用Maven创建项目 上图中Next 2.继续Next 3.选maven ...

  8. 封装一个简单的原生js焦点轮播图插件

    轮播图实现的效果为,鼠标移入左右箭头会出现,可以点击切换图片,下面的小圆点会跟随,可以循环播放(为了方便理解,没有补2张图做无缝轮播).本篇文章的主要目的是分享封装插件的思路. 轮播图我一开始是写成非 ...

  9. git and github问题集锦

    本人遇到的:

  10. selenium(五)伪造浏览器

    简介: 这个就比较好玩了,大家还记得以前的QQ小尾巴么?还有百度贴吧的小尾巴,就是那个来自***的iphone7,这个功能. 这个功能是基于浏览器的user-agent功能实现的. 还是httpbin ...