转自http://blog.csdn.net/firefight/article/details/6452188

是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

使用OPENCV训练手写数字识别分类器

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

    1. #include "stdafx.h"
    2. #include <fstream>
    3. #include "opencv2/opencv.hpp"
    4. #include <vector>
    5. using namespace std;
    6. using namespace cv;
    7. #define SHOW_PROCESS 0
    8. #define ON_STUDY 0
    9. class NumTrainData
    10. {
    11. public:
    12. NumTrainData()
    13. {
    14. memset(data, 0, sizeof(data));
    15. result = -1;
    16. }
    17. public:
    18. float data[64];
    19. int result;
    20. };
    21. vector<NumTrainData> buffer;
    22. int featureLen = 64;
    23. void swapBuffer(char* buf)
    24. {
    25. char temp;
    26. temp = *(buf);
    27. *buf = *(buf+3);
    28. *(buf+3) = temp;
    29. temp = *(buf+1);
    30. *(buf+1) = *(buf+2);
    31. *(buf+2) = temp;
    32. }
    33. void GetROI(Mat& src, Mat& dst)
    34. {
    35. int left, right, top, bottom;
    36. left = src.cols;
    37. right = 0;
    38. top = src.rows;
    39. bottom = 0;
    40. //Get valid area
    41. for(int i=0; i<src.rows; i++)
    42. {
    43. for(int j=0; j<src.cols; j++)
    44. {
    45. if(src.at<uchar>(i, j) > 0)
    46. {
    47. if(j<left) left = j;
    48. if(j>right) right = j;
    49. if(i<top) top = i;
    50. if(i>bottom) bottom = i;
    51. }
    52. }
    53. }
    54. //Point center;
    55. //center.x = (left + right) / 2;
    56. //center.y = (top + bottom) / 2;
    57. int width = right - left;
    58. int height = bottom - top;
    59. int len = (width < height) ? height : width;
    60. //Create a squre
    61. dst = Mat::zeros(len, len, CV_8UC1);
    62. //Copy valid data to squre center
    63. Rect dstRect((len - width)/2, (len - height)/2, width, height);
    64. Rect srcRect(left, top, width, height);
    65. Mat dstROI = dst(dstRect);
    66. Mat srcROI = src(srcRect);
    67. srcROI.copyTo(dstROI);
    68. }
    69. int ReadTrainData(int maxCount)
    70. {
    71. //Open image and label file
    72. const char fileName[] = "../res/train-images.idx3-ubyte";
    73. const char labelFileName[] = "../res/train-labels.idx1-ubyte";
    74. ifstream lab_ifs(labelFileName, ios_base::binary);
    75. ifstream ifs(fileName, ios_base::binary);
    76. if( ifs.fail() == true )
    77. return -1;
    78. if( lab_ifs.fail() == true )
    79. return -1;
    80. //Read train data number and image rows / cols
    81. char magicNum[4], ccount[4], crows[4], ccols[4];
    82. ifs.read(magicNum, sizeof(magicNum));
    83. ifs.read(ccount, sizeof(ccount));
    84. ifs.read(crows, sizeof(crows));
    85. ifs.read(ccols, sizeof(ccols));
    86. int count, rows, cols;
    87. swapBuffer(ccount);
    88. swapBuffer(crows);
    89. swapBuffer(ccols);
    90. memcpy(&count, ccount, sizeof(count));
    91. memcpy(&rows, crows, sizeof(rows));
    92. memcpy(&cols, ccols, sizeof(cols));
    93. //Just skip label header
    94. lab_ifs.read(magicNum, sizeof(magicNum));
    95. lab_ifs.read(ccount, sizeof(ccount));
    96. //Create source and show image matrix
    97. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    98. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    99. Mat img, dst;
    100. char label = 0;
    101. Scalar templateColor(255, 0, 255 );
    102. NumTrainData rtd;
    103. //int loop = 1000;
    104. int total = 0;
    105. while(!ifs.eof())
    106. {
    107. if(total >= count)
    108. break;
    109. total++;
    110. cout << total << endl;
    111. //Read label
    112. lab_ifs.read(&label, 1);
    113. label = label + '0';
    114. //Read source data
    115. ifs.read((char*)src.data, rows * cols);
    116. GetROI(src, dst);
    117. #if(SHOW_PROCESS)
    118. //Too small to watch
    119. img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
    120. resize(dst, img, img.size());
    121. stringstream ss;
    122. ss << "Number " << label;
    123. string text = ss.str();
    124. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    125. //imshow("img", img);
    126. #endif
    127. rtd.result = label;
    128. resize(dst, temp, temp.size());
    129. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    130. for(int i = 0; i<8; i++)
    131. {
    132. for(int j = 0; j<8; j++)
    133. {
    134. rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
    135. }
    136. }
    137. buffer.push_back(rtd);
    138. //if(waitKey(0)==27) //ESC to quit
    139. //  break;
    140. maxCount--;
    141. if(maxCount == 0)
    142. break;
    143. }
    144. ifs.close();
    145. lab_ifs.close();
    146. return 0;
    147. }
    148. void newRtStudy(vector<NumTrainData>& trainData)
    149. {
    150. int testCount = trainData.size();
    151. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    152. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    153. for (int i= 0; i< testCount; i++)
    154. {
    155. NumTrainData td = trainData.at(i);
    156. memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));
    157. res.at<unsigned int>(i, 0) = td.result;
    158. }
    159. /////////////START RT TRAINNING//////////////////
    160. CvRTrees forest;
    161. CvMat* var_importance = 0;
    162. forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
    163. CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    164. forest.save( "new_rtrees.xml" );
    165. }
    166. int newRtPredict()
    167. {
    168. CvRTrees forest;
    169. forest.load( "new_rtrees.xml" );
    170. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    171. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    172. ifstream lab_ifs(labelFileName, ios_base::binary);
    173. ifstream ifs(fileName, ios_base::binary);
    174. if( ifs.fail() == true )
    175. return -1;
    176. if( lab_ifs.fail() == true )
    177. return -1;
    178. char magicNum[4], ccount[4], crows[4], ccols[4];
    179. ifs.read(magicNum, sizeof(magicNum));
    180. ifs.read(ccount, sizeof(ccount));
    181. ifs.read(crows, sizeof(crows));
    182. ifs.read(ccols, sizeof(ccols));
    183. int count, rows, cols;
    184. swapBuffer(ccount);
    185. swapBuffer(crows);
    186. swapBuffer(ccols);
    187. memcpy(&count, ccount, sizeof(count));
    188. memcpy(&rows, crows, sizeof(rows));
    189. memcpy(&cols, ccols, sizeof(cols));
    190. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    191. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    192. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    193. Mat img, dst;
    194. //Just skip label header
    195. lab_ifs.read(magicNum, sizeof(magicNum));
    196. lab_ifs.read(ccount, sizeof(ccount));
    197. char label = 0;
    198. Scalar templateColor(255, 0, 0);
    199. NumTrainData rtd;
    200. int right = 0, error = 0, total = 0;
    201. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    202. while(ifs.good())
    203. {
    204. //Read label
    205. lab_ifs.read(&label, 1);
    206. label = label + '0';
    207. //Read data
    208. ifs.read((char*)src.data, rows * cols);
    209. GetROI(src, dst);
    210. //Too small to watch
    211. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    212. resize(dst, img, img.size());
    213. rtd.result = label;
    214. resize(dst, temp, temp.size());
    215. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    216. for(int i = 0; i<8; i++)
    217. {
    218. for(int j = 0; j<8; j++)
    219. {
    220. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    221. }
    222. }
    223. if(total >= count)
    224. break;
    225. char ret = (char)forest.predict(m);
    226. if(ret == label)
    227. {
    228. right++;
    229. if(total <= 5000)
    230. right_1++;
    231. else
    232. right_2++;
    233. }
    234. else
    235. {
    236. error++;
    237. if(total <= 5000)
    238. error_1++;
    239. else
    240. error_2++;
    241. }
    242. total++;
    243. #if(SHOW_PROCESS)
    244. stringstream ss;
    245. ss << "Number " << label << ", predict " << ret;
    246. string text = ss.str();
    247. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    248. imshow("img", img);
    249. if(waitKey(0)==27) //ESC to quit
    250. break;
    251. #endif
    252. }
    253. ifs.close();
    254. lab_ifs.close();
    255. stringstream ss;
    256. ss << "Total " << total << ", right " << right <<", error " << error;
    257. string text = ss.str();
    258. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    259. imshow("img", img);
    260. waitKey(0);
    261. return 0;
    262. }
    263. void newSvmStudy(vector<NumTrainData>& trainData)
    264. {
    265. int testCount = trainData.size();
    266. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    267. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    268. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    269. for (int i= 0; i< testCount; i++)
    270. {
    271. NumTrainData td = trainData.at(i);
    272. memcpy(m.data, td.data, featureLen*sizeof(float));
    273. normalize(m, m);
    274. memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
    275. res.at<unsigned int>(i, 0) = td.result;
    276. }
    277. /////////////START SVM TRAINNING//////////////////
    278. CvSVM svm = CvSVM();
    279. CvSVMParams param;
    280. CvTermCriteria criteria;
    281. criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    282. param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    283. svm.train(data, res, Mat(), Mat(), param);
    284. svm.save( "SVM_DATA.xml" );
    285. }
    286. int newSvmPredict()
    287. {
    288. CvSVM svm = CvSVM();
    289. svm.load( "SVM_DATA.xml" );
    290. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    291. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    292. ifstream lab_ifs(labelFileName, ios_base::binary);
    293. ifstream ifs(fileName, ios_base::binary);
    294. if( ifs.fail() == true )
    295. return -1;
    296. if( lab_ifs.fail() == true )
    297. return -1;
    298. char magicNum[4], ccount[4], crows[4], ccols[4];
    299. ifs.read(magicNum, sizeof(magicNum));
    300. ifs.read(ccount, sizeof(ccount));
    301. ifs.read(crows, sizeof(crows));
    302. ifs.read(ccols, sizeof(ccols));
    303. int count, rows, cols;
    304. swapBuffer(ccount);
    305. swapBuffer(crows);
    306. swapBuffer(ccols);
    307. memcpy(&count, ccount, sizeof(count));
    308. memcpy(&rows, crows, sizeof(rows));
    309. memcpy(&cols, ccols, sizeof(cols));
    310. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    311. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    312. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    313. Mat img, dst;
    314. //Just skip label header
    315. lab_ifs.read(magicNum, sizeof(magicNum));
    316. lab_ifs.read(ccount, sizeof(ccount));
    317. char label = 0;
    318. Scalar templateColor(255, 0, 0);
    319. NumTrainData rtd;
    320. int right = 0, error = 0, total = 0;
    321. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    322. while(ifs.good())
    323. {
    324. //Read label
    325. lab_ifs.read(&label, 1);
    326. label = label + '0';
    327. //Read data
    328. ifs.read((char*)src.data, rows * cols);
    329. GetROI(src, dst);
    330. //Too small to watch
    331. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    332. resize(dst, img, img.size());
    333. rtd.result = label;
    334. resize(dst, temp, temp.size());
    335. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    336. for(int i = 0; i<8; i++)
    337. {
    338. for(int j = 0; j<8; j++)
    339. {
    340. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    341. }
    342. }
    343. if(total >= count)
    344. break;
    345. normalize(m, m);
    346. char ret = (char)svm.predict(m);
    347. if(ret == label)
    348. {
    349. right++;
    350. if(total <= 5000)
    351. right_1++;
    352. else
    353. right_2++;
    354. }
    355. else
    356. {
    357. error++;
    358. if(total <= 5000)
    359. error_1++;
    360. else
    361. error_2++;
    362. }
    363. total++;
    364. #if(SHOW_PROCESS)
    365. stringstream ss;
    366. ss << "Number " << label << ", predict " << ret;
    367. string text = ss.str();
    368. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    369. imshow("img", img);
    370. if(waitKey(0)==27) //ESC to quit
    371. break;
    372. #endif
    373. }
    374. ifs.close();
    375. lab_ifs.close();
    376. stringstream ss;
    377. ss << "Total " << total << ", right " << right <<", error " << error;
    378. string text = ss.str();
    379. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    380. imshow("img", img);
    381. waitKey(0);
    382. return 0;
    383. }
    384. int main( int argc, char *argv[] )
    385. {
    386. #if(ON_STUDY)
    387. int maxCount = 60000;
    388. ReadTrainData(maxCount);
    389. //newRtStudy(buffer);
    390. newSvmStudy(buffer);
    391. #else
    392. //newRtPredict();
    393. newSvmPredict();
    394. #endif
    395. return 0;
    396. }
    397. //from: http://blog.csdn.net/yangtrees/article/details/7458466

学习OpenCV——SVM 手写数字检测的更多相关文章

  1. 基于opencv的手写数字识别(MFC,HOG,SVM)

    参考了秋风细雨的文章:http://blog.csdn.net/candyforever/article/details/8564746 花了点时间编写出了程序,先看看效果吧. 识别效果大概都能正确. ...

  2. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  3. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  4. 基于opencv的手写数字字符识别

    摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...

  5. mnist手写数字检测

    # -*- coding: utf-8 -*- """ Created on Tue Apr 23 06:16:04 2019 @author: 92958 " ...

  6. 简单HOG+SVM mnist手写数字分类

    使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images- ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. SVM学习笔记(二)----手写数字识别

    引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...

  9. 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)

    @ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...

随机推荐

  1. 创建需要计时器的windows service

    1.在VS中建立windows service后,应该添加一个安装程序. 2.在默认的Service1.cs设计界面右键,添加安装程序,生成ProjectInstaller.包含两个类serviceP ...

  2. 常见的sql语句 注意点及用法【区分mysql 和Sqlserver】

    如何判断在字符串字段中是否包含某个字符串 mysql:   url:http://www.springload.cn/springload/detail/399 mysql> SELECT * ...

  3. Iterator用法

    <% List<Emp> all=DAOFactory.getIEmpDAOInstance().findAll(keyWord); Itrator<Emp> iter= ...

  4. 用简单直白的方式讲解A星寻路算法原理

    很多游戏特别是rts,rpg类游戏,都需要用到寻路.寻路算法有深度优先搜索(DFS),广度优先搜索(BFS),A星算法等,而A星算法是一种具备启发性策略的算法,效率是几种算法中最高的,因此也成为游戏中 ...

  5. javascript使用两个逻辑非运算符(!!)的原因

    javascript使用两个逻辑非运算符(!!)的原因: 在有些代码中可能大家可能会注意到有些地方使用了两个逻辑非运算符,第一感觉就是没有必要,比如操作数是true的话,使用两个逻辑非的返回值还是tr ...

  6. [LintCode] Maximal Square 最大正方形

    Given a 2D binary matrix filled with 0's and 1's, find the largest square containing all 1's and ret ...

  7. Sqoop_mysql,hive,hdfs导入导出操作

    前言: 搭建环境,这里使用cdh版hadoop+hive+sqoop+mysql 下载 hadoop-2.5.0-cdh5.3.6.tar.gz hive-0.13.1-cdh5.3.6.tar.gz ...

  8. zk 隐藏网页文件后缀

    前台(test.zul): <a label="隐藏地址" href="/Bandbox/test.html"/> web.xml添加 <se ...

  9. 论meta name= viewport content= width=device-width initial-scale=1 minimum-scale=1 maximum-scale=1的作用

    一.先明白几个概念 phys.width: device-width: 一般我们所指的宽度width即为phys.width,而device-width又称为css-width. 其中我们可以获取ph ...

  10. 配置samba服务一例

    问题: 在/data/share目录下建立三个子目录public.training.devel用途如下 public目录用于存放公共数据,如公司的规章制度 training目录用于存放公司的技术培训资 ...