转自http://blog.csdn.net/firefight/article/details/6452188

是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

使用OPENCV训练手写数字识别分类器

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

    1. #include "stdafx.h"
    2. #include <fstream>
    3. #include "opencv2/opencv.hpp"
    4. #include <vector>
    5. using namespace std;
    6. using namespace cv;
    7. #define SHOW_PROCESS 0
    8. #define ON_STUDY 0
    9. class NumTrainData
    10. {
    11. public:
    12. NumTrainData()
    13. {
    14. memset(data, 0, sizeof(data));
    15. result = -1;
    16. }
    17. public:
    18. float data[64];
    19. int result;
    20. };
    21. vector<NumTrainData> buffer;
    22. int featureLen = 64;
    23. void swapBuffer(char* buf)
    24. {
    25. char temp;
    26. temp = *(buf);
    27. *buf = *(buf+3);
    28. *(buf+3) = temp;
    29. temp = *(buf+1);
    30. *(buf+1) = *(buf+2);
    31. *(buf+2) = temp;
    32. }
    33. void GetROI(Mat& src, Mat& dst)
    34. {
    35. int left, right, top, bottom;
    36. left = src.cols;
    37. right = 0;
    38. top = src.rows;
    39. bottom = 0;
    40. //Get valid area
    41. for(int i=0; i<src.rows; i++)
    42. {
    43. for(int j=0; j<src.cols; j++)
    44. {
    45. if(src.at<uchar>(i, j) > 0)
    46. {
    47. if(j<left) left = j;
    48. if(j>right) right = j;
    49. if(i<top) top = i;
    50. if(i>bottom) bottom = i;
    51. }
    52. }
    53. }
    54. //Point center;
    55. //center.x = (left + right) / 2;
    56. //center.y = (top + bottom) / 2;
    57. int width = right - left;
    58. int height = bottom - top;
    59. int len = (width < height) ? height : width;
    60. //Create a squre
    61. dst = Mat::zeros(len, len, CV_8UC1);
    62. //Copy valid data to squre center
    63. Rect dstRect((len - width)/2, (len - height)/2, width, height);
    64. Rect srcRect(left, top, width, height);
    65. Mat dstROI = dst(dstRect);
    66. Mat srcROI = src(srcRect);
    67. srcROI.copyTo(dstROI);
    68. }
    69. int ReadTrainData(int maxCount)
    70. {
    71. //Open image and label file
    72. const char fileName[] = "../res/train-images.idx3-ubyte";
    73. const char labelFileName[] = "../res/train-labels.idx1-ubyte";
    74. ifstream lab_ifs(labelFileName, ios_base::binary);
    75. ifstream ifs(fileName, ios_base::binary);
    76. if( ifs.fail() == true )
    77. return -1;
    78. if( lab_ifs.fail() == true )
    79. return -1;
    80. //Read train data number and image rows / cols
    81. char magicNum[4], ccount[4], crows[4], ccols[4];
    82. ifs.read(magicNum, sizeof(magicNum));
    83. ifs.read(ccount, sizeof(ccount));
    84. ifs.read(crows, sizeof(crows));
    85. ifs.read(ccols, sizeof(ccols));
    86. int count, rows, cols;
    87. swapBuffer(ccount);
    88. swapBuffer(crows);
    89. swapBuffer(ccols);
    90. memcpy(&count, ccount, sizeof(count));
    91. memcpy(&rows, crows, sizeof(rows));
    92. memcpy(&cols, ccols, sizeof(cols));
    93. //Just skip label header
    94. lab_ifs.read(magicNum, sizeof(magicNum));
    95. lab_ifs.read(ccount, sizeof(ccount));
    96. //Create source and show image matrix
    97. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    98. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    99. Mat img, dst;
    100. char label = 0;
    101. Scalar templateColor(255, 0, 255 );
    102. NumTrainData rtd;
    103. //int loop = 1000;
    104. int total = 0;
    105. while(!ifs.eof())
    106. {
    107. if(total >= count)
    108. break;
    109. total++;
    110. cout << total << endl;
    111. //Read label
    112. lab_ifs.read(&label, 1);
    113. label = label + '0';
    114. //Read source data
    115. ifs.read((char*)src.data, rows * cols);
    116. GetROI(src, dst);
    117. #if(SHOW_PROCESS)
    118. //Too small to watch
    119. img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
    120. resize(dst, img, img.size());
    121. stringstream ss;
    122. ss << "Number " << label;
    123. string text = ss.str();
    124. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    125. //imshow("img", img);
    126. #endif
    127. rtd.result = label;
    128. resize(dst, temp, temp.size());
    129. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    130. for(int i = 0; i<8; i++)
    131. {
    132. for(int j = 0; j<8; j++)
    133. {
    134. rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
    135. }
    136. }
    137. buffer.push_back(rtd);
    138. //if(waitKey(0)==27) //ESC to quit
    139. //  break;
    140. maxCount--;
    141. if(maxCount == 0)
    142. break;
    143. }
    144. ifs.close();
    145. lab_ifs.close();
    146. return 0;
    147. }
    148. void newRtStudy(vector<NumTrainData>& trainData)
    149. {
    150. int testCount = trainData.size();
    151. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    152. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    153. for (int i= 0; i< testCount; i++)
    154. {
    155. NumTrainData td = trainData.at(i);
    156. memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));
    157. res.at<unsigned int>(i, 0) = td.result;
    158. }
    159. /////////////START RT TRAINNING//////////////////
    160. CvRTrees forest;
    161. CvMat* var_importance = 0;
    162. forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
    163. CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    164. forest.save( "new_rtrees.xml" );
    165. }
    166. int newRtPredict()
    167. {
    168. CvRTrees forest;
    169. forest.load( "new_rtrees.xml" );
    170. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    171. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    172. ifstream lab_ifs(labelFileName, ios_base::binary);
    173. ifstream ifs(fileName, ios_base::binary);
    174. if( ifs.fail() == true )
    175. return -1;
    176. if( lab_ifs.fail() == true )
    177. return -1;
    178. char magicNum[4], ccount[4], crows[4], ccols[4];
    179. ifs.read(magicNum, sizeof(magicNum));
    180. ifs.read(ccount, sizeof(ccount));
    181. ifs.read(crows, sizeof(crows));
    182. ifs.read(ccols, sizeof(ccols));
    183. int count, rows, cols;
    184. swapBuffer(ccount);
    185. swapBuffer(crows);
    186. swapBuffer(ccols);
    187. memcpy(&count, ccount, sizeof(count));
    188. memcpy(&rows, crows, sizeof(rows));
    189. memcpy(&cols, ccols, sizeof(cols));
    190. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    191. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    192. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    193. Mat img, dst;
    194. //Just skip label header
    195. lab_ifs.read(magicNum, sizeof(magicNum));
    196. lab_ifs.read(ccount, sizeof(ccount));
    197. char label = 0;
    198. Scalar templateColor(255, 0, 0);
    199. NumTrainData rtd;
    200. int right = 0, error = 0, total = 0;
    201. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    202. while(ifs.good())
    203. {
    204. //Read label
    205. lab_ifs.read(&label, 1);
    206. label = label + '0';
    207. //Read data
    208. ifs.read((char*)src.data, rows * cols);
    209. GetROI(src, dst);
    210. //Too small to watch
    211. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    212. resize(dst, img, img.size());
    213. rtd.result = label;
    214. resize(dst, temp, temp.size());
    215. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    216. for(int i = 0; i<8; i++)
    217. {
    218. for(int j = 0; j<8; j++)
    219. {
    220. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    221. }
    222. }
    223. if(total >= count)
    224. break;
    225. char ret = (char)forest.predict(m);
    226. if(ret == label)
    227. {
    228. right++;
    229. if(total <= 5000)
    230. right_1++;
    231. else
    232. right_2++;
    233. }
    234. else
    235. {
    236. error++;
    237. if(total <= 5000)
    238. error_1++;
    239. else
    240. error_2++;
    241. }
    242. total++;
    243. #if(SHOW_PROCESS)
    244. stringstream ss;
    245. ss << "Number " << label << ", predict " << ret;
    246. string text = ss.str();
    247. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    248. imshow("img", img);
    249. if(waitKey(0)==27) //ESC to quit
    250. break;
    251. #endif
    252. }
    253. ifs.close();
    254. lab_ifs.close();
    255. stringstream ss;
    256. ss << "Total " << total << ", right " << right <<", error " << error;
    257. string text = ss.str();
    258. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    259. imshow("img", img);
    260. waitKey(0);
    261. return 0;
    262. }
    263. void newSvmStudy(vector<NumTrainData>& trainData)
    264. {
    265. int testCount = trainData.size();
    266. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    267. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    268. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    269. for (int i= 0; i< testCount; i++)
    270. {
    271. NumTrainData td = trainData.at(i);
    272. memcpy(m.data, td.data, featureLen*sizeof(float));
    273. normalize(m, m);
    274. memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
    275. res.at<unsigned int>(i, 0) = td.result;
    276. }
    277. /////////////START SVM TRAINNING//////////////////
    278. CvSVM svm = CvSVM();
    279. CvSVMParams param;
    280. CvTermCriteria criteria;
    281. criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    282. param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    283. svm.train(data, res, Mat(), Mat(), param);
    284. svm.save( "SVM_DATA.xml" );
    285. }
    286. int newSvmPredict()
    287. {
    288. CvSVM svm = CvSVM();
    289. svm.load( "SVM_DATA.xml" );
    290. const char fileName[] = "../res/t10k-images.idx3-ubyte";
    291. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    292. ifstream lab_ifs(labelFileName, ios_base::binary);
    293. ifstream ifs(fileName, ios_base::binary);
    294. if( ifs.fail() == true )
    295. return -1;
    296. if( lab_ifs.fail() == true )
    297. return -1;
    298. char magicNum[4], ccount[4], crows[4], ccols[4];
    299. ifs.read(magicNum, sizeof(magicNum));
    300. ifs.read(ccount, sizeof(ccount));
    301. ifs.read(crows, sizeof(crows));
    302. ifs.read(ccols, sizeof(ccols));
    303. int count, rows, cols;
    304. swapBuffer(ccount);
    305. swapBuffer(crows);
    306. swapBuffer(ccols);
    307. memcpy(&count, ccount, sizeof(count));
    308. memcpy(&rows, crows, sizeof(rows));
    309. memcpy(&cols, ccols, sizeof(cols));
    310. Mat src = Mat::zeros(rows, cols, CV_8UC1);
    311. Mat temp = Mat::zeros(8, 8, CV_8UC1);
    312. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    313. Mat img, dst;
    314. //Just skip label header
    315. lab_ifs.read(magicNum, sizeof(magicNum));
    316. lab_ifs.read(ccount, sizeof(ccount));
    317. char label = 0;
    318. Scalar templateColor(255, 0, 0);
    319. NumTrainData rtd;
    320. int right = 0, error = 0, total = 0;
    321. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    322. while(ifs.good())
    323. {
    324. //Read label
    325. lab_ifs.read(&label, 1);
    326. label = label + '0';
    327. //Read data
    328. ifs.read((char*)src.data, rows * cols);
    329. GetROI(src, dst);
    330. //Too small to watch
    331. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
    332. resize(dst, img, img.size());
    333. rtd.result = label;
    334. resize(dst, temp, temp.size());
    335. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
    336. for(int i = 0; i<8; i++)
    337. {
    338. for(int j = 0; j<8; j++)
    339. {
    340. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
    341. }
    342. }
    343. if(total >= count)
    344. break;
    345. normalize(m, m);
    346. char ret = (char)svm.predict(m);
    347. if(ret == label)
    348. {
    349. right++;
    350. if(total <= 5000)
    351. right_1++;
    352. else
    353. right_2++;
    354. }
    355. else
    356. {
    357. error++;
    358. if(total <= 5000)
    359. error_1++;
    360. else
    361. error_2++;
    362. }
    363. total++;
    364. #if(SHOW_PROCESS)
    365. stringstream ss;
    366. ss << "Number " << label << ", predict " << ret;
    367. string text = ss.str();
    368. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    369. imshow("img", img);
    370. if(waitKey(0)==27) //ESC to quit
    371. break;
    372. #endif
    373. }
    374. ifs.close();
    375. lab_ifs.close();
    376. stringstream ss;
    377. ss << "Total " << total << ", right " << right <<", error " << error;
    378. string text = ss.str();
    379. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    380. imshow("img", img);
    381. waitKey(0);
    382. return 0;
    383. }
    384. int main( int argc, char *argv[] )
    385. {
    386. #if(ON_STUDY)
    387. int maxCount = 60000;
    388. ReadTrainData(maxCount);
    389. //newRtStudy(buffer);
    390. newSvmStudy(buffer);
    391. #else
    392. //newRtPredict();
    393. newSvmPredict();
    394. #endif
    395. return 0;
    396. }
    397. //from: http://blog.csdn.net/yangtrees/article/details/7458466

学习OpenCV——SVM 手写数字检测的更多相关文章

  1. 基于opencv的手写数字识别(MFC,HOG,SVM)

    参考了秋风细雨的文章:http://blog.csdn.net/candyforever/article/details/8564746 花了点时间编写出了程序,先看看效果吧. 识别效果大概都能正确. ...

  2. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  3. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  4. 基于opencv的手写数字字符识别

    摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...

  5. mnist手写数字检测

    # -*- coding: utf-8 -*- """ Created on Tue Apr 23 06:16:04 2019 @author: 92958 " ...

  6. 简单HOG+SVM mnist手写数字分类

    使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images- ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. SVM学习笔记(二)----手写数字识别

    引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...

  9. 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)

    @ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...

随机推荐

  1. Ubuntu Gnome 14.04.2 lts 折腾笔记

    unity感觉不爽,于是来折腾gnome3 = = 首先去官网下载ubuntu gnome 14.04.2 lts的包(种子:http://cdimage.ubuntu.com/ubuntu-gnom ...

  2. jsonkit mrc于arc混编

  3. IOS 蓝牙相关-app作为外设被连接的实现(3)

    再上一节说了app作为central连接peripheral的情况,这一节介绍如何使用app发布一个peripheral,给其他的central连接 还是这张图,central模式用的都是左边的类,而 ...

  4. vim operation

    note:  转自 www.quora.com ,很好的网站. 具体链接如下: https://www.quora.com/What-are-some-impressive-demos-of-Vim- ...

  5. 【hihoCoder】1049.后序遍历

    问题:http://hihocoder.com/problemset/problem/1049?sid=767510 已知一棵二叉树的前序遍历及中序遍历结果,求后序遍历结果 思路: 前序:根-左子树- ...

  6. 安装redis时遇到zmalloc.h:50:31: 致命错误:jemalloc/jemalloc.h:没有那个文件或目录

    参考博文,http://www.phperz.com/article/14/1219/42002.html 解决办法 make MALLOC=libc

  7. zju(4)使用busybox制作根文件系统

    1.实验目的 1.学习和掌握busybox相关知识及应用: 2.学会使用交叉编译器定制一个busybox: 3.利用该busybox制作一个文件系统: 4.熟悉根文件系统组织结构: 5.定制.编译ra ...

  8. 数组遍历map和each使用

    <body> <input type="/> <input type="/> <input type="/> </b ...

  9. 已知一个日期和天数, 求多少天后的日期(是那个超时代码的AC版)

    #include <stdio.h> #include <string.h> ; int judge_year(int x) { == || x % == && ...

  10. 读过的laravel文章

    Laravel 中使用 JWT(Json Web Token) 实现基于API的用户认证 http://www.tuicool.com/articles/IRJnaa api token https: ...