简单HOG+SVM mnist手写数字分类
使用工具 :VS2013 + OpenCV 3.1
数据集:minst
训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml
数据准备
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
首先我们利用matlab将数据转换成 .bmp 图片格式
fid_image=fopen('train-images.idx3-ubyte','r');
fid_label=fopen('train-labels.idx1-ubyte','r');
% Read the first 16 Bytes
magicnumber=fread(fid_image,4);
size=fread(fid_image,4);
row=fread(fid_image,4);
col=fread(fid_image,4);
% Read the first 8 Bytes
extra=fread(fid_label,8);
% Read labels related to images
imageIndex=fread(fid_label);
Num=length(imageIndex);
% Count repeat times of 0 to 9
cnt=zeros(1,10);
for k=1:Num
image=fread(fid_image,[max(row),max(col)]); % Get image data
val=imageIndex(k); % Get value of image
for i=0:9
if val==i
cnt(val+1)=cnt(val+1)+1;
end
end
if cnt(val+1)<10
str=[num2str(val),'_000',num2str(cnt(val+1)),'.bmp'];
elseif cnt(val+1)<100
str=[num2str(val),'_00',num2str(cnt(val+1)),'.bmp'];
elseif cnt(val+1)<1000
str=[num2str(val),'_0',num2str(cnt(val+1)),'.bmp'];
else
str=[num2str(val),'_',num2str(cnt(val+1)),'.bmp'];
end
imwrite(image',str);
end
fclose(fid_image);
fclose(fid_label);
然后使用cmd指令写入图片路径: dir /b/s/p/w *.bmp > num.txt 添加标签,如下图
然后打乱样本顺序。
训练
int main0()
{
vector<string> img_path;//输入文件名变量
vector<int> img_catg;
int nLine = ;
string line;
size_t pos;
ifstream svm_data("./train-images/random.txt");//训练样本图片的路径都写在这个txt文件中,使用bat批处理文件可以得到这个txt文件
unsigned long n;
while (svm_data)//将训练样本文件依次读取进来
{
if (getline(svm_data, line))
{
nLine++;
pos = line.find_last_of(' ');
img_path.push_back(line.substr(, pos));//图像路径
img_catg.push_back(atoi(line.substr(pos + ).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错
}
} svm_data.close();//关闭文件
int nImgNum = nLine; //nImgNum是样本数量,只有文本行数的一半,另一半是标签
cv::Mat data_mat(nImgNum, , CV_32FC1);//第二个参数,即矩阵的列是由下面的descriptors的大小决定的,可以由descriptors.size()得到,且对于不同大小的输入训练图片,这个值是不同的
data_mat.setTo(cv::Scalar());
//类型矩阵,存储每个样本的类型标志
cv::Mat res_mat(nImgNum, , CV_32S);
res_mat.setTo(cv::Scalar());
cv::Mat src;
cv::Mat trainImg(cv::Size(, ), , );//需要分析的图片,这里默认设定图片是28*28大小,所以上面定义了324,如果要更改图片大小,可以先用debug查看一下descriptors是多少,然后设定好再运行 //处理HOG特征
for (string::size_type i = ; i != img_path.size(); i++)
{
src = cv::imread(img_path[i].c_str(), );
if (src.data == NULL)//if (src == NULL)
{
cout << " can not load the image: " << img_path[i].c_str() << endl;
continue;
} //cout << " 处理: " << img_path[i].c_str() << endl; cv::resize(src, trainImg, trainImg.size());
cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(, ), cv::Size(, ), cv::Size(, ), cv::Size(, ), );
vector<float>descriptors;//存放结果
hog->compute(trainImg, descriptors, cv::Size(, ), cv::Size(, )); //Hog特征计算
//cout << "HOG dims: " << descriptors.size() << endl;
n = ;
for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
{
//cvmSet(data_mat, i, n, *iter);
data_mat.at<float>(i, n) = *iter;//存储HOG特征
n++;
}
//cvmSet(res_mat, i, 0, img_catg[i]);
res_mat.at<int>(i, ) = img_catg[i];
//cout << " 处理完毕: " << img_path[i].c_str() << " " << img_catg[i] << endl;
}
cout << "computed features!" << endl; cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();//新建一个SVM
svm->setType(cv::ml::SVM::C_SVC);
svm->setKernel(cv::ml::SVM::LINEAR);
svm->setC();
//-------------------不使用参数优化-------------------------//
svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, , FLT_EPSILON));
svm->train(data_mat, cv::ml::ROW_SAMPLE, res_mat);//训练数据
//-------------------参数优化-------------------------//
//svm->setTermCriteria = cv::TermCriteria(cv::TermCriteria::MAX_ITER, (int)1e7, 1e-6);
//cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, res_mat);
//svm->trainAuto(td, 10); //保存训练好的分类器
svm->save("HOG_SVM_DATA.xml");
cout << "saved model!" << endl;
//检测样本
cv::Mat test;//IplImage *test;
char result[];
vector<string> img_test_path;
vector<int> img_test_catg;
int coorect = ;
ifstream img_tst("./test-images/random.txt"); //加载需要预测的图片集合,这个文本里存放的是图片全路径,不要标签
while (img_tst)
{
if (getline(img_tst, line))
{
pos = line.find_last_of(' ');
img_test_catg.push_back(atoi(line.substr(pos + ).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错
img_test_path.push_back(line.substr(, pos));//图像路径
}
}
img_tst.close(); ofstream predict_txt("SVM_PREDICT.txt");//把预测结果存储在这个文本中
for (string::size_type j = ; j != img_test_path.size(); j++)//依次遍历所有的待检测图片
{
test = cv::imread(img_test_path[j].c_str(), );
if (test.data == NULL)//test == NULL
{
cout << " can not load the image: " << img_test_path[j].c_str() << endl;
continue;
}
cv::Mat trainTempImg(cv::Size(, ), , );
trainTempImg.setTo(cv::Scalar());
cv::resize(test, trainTempImg, trainTempImg.size());
cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(, ), cv::Size(, ), cv::Size(, ), cv::Size(, ), );
vector<float>descriptors;//结果数组
hog->compute(trainTempImg, descriptors, cv::Size(, ), cv::Size(, ));
//cout << "HOG dims: " << descriptors.size() << endl;
cv::Mat SVMtrainMat(, descriptors.size(), CV_32FC1);
int n = ;
for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
{
SVMtrainMat.at<float>(, n) = *iter;
n++;
} int ret = svm->predict(SVMtrainMat);//检测结果
if (ret == img_test_catg[j])
coorect++;
sprintf(result, "%s %d\r\n", img_test_path[j].c_str(), ret);
predict_txt << result; //输出检测结果到文本
}
predict_txt.close();
cout << coorect* / img_test_path.size() << "%" << endl;
return ;
}
测试
int main(int argc, char* argv[])
{
cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
svm = cv::ml::SVM::load("HOG_SVM_DATA.xml");;//加载训练好的xml文件,这里训练的是10K个手写数字
//检测样本
cv::Mat test;
char result[]; //存放预测结果
test = cv::imread("6.bmp", ); //待预测图片,用系统自带的画图工具随便手写
if (!test.data)
{
MessageBox(NULL, TEXT("待预测图像不存在!"), TEXT("提示"), MB_ICONWARNING);
return -;
}
cv::Mat trainTempImg(cv::Size(, ), , );
trainTempImg.setTo(cv::Scalar());
cv::resize(test, trainTempImg, trainTempImg.size());
cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(, ), cv::Size(, ), cv::Size(, ), cv::Size(, ), );
vector<float>descriptors;//结果数组
hog->compute(trainTempImg, descriptors, cv::Size(, ), cv::Size(, ));
//cout << "HOG dims: " << descriptors.size() << endl;
cv::Mat SVMtrainMat(, descriptors.size(), CV_32FC1);
int n = ;
for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
{
SVMtrainMat.at<float>(, n) = *iter;
n++;
}
int ret = svm->predict(SVMtrainMat);//检测结果
sprintf(result, "%d\r\n", ret);
cv::namedWindow("dst", );
cv::imshow("dst", test);
MessageBox(NULL, result, TEXT("预测结果"), MB_OK);
return ;
}
简单HOG+SVM mnist手写数字分类的更多相关文章
- MNIST手写数字分类simple版(03-2)
simple版本nn模型 训练手写数字处理 MNIST_data数据 百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...
- mnist手写数字问题初体验
上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务.那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Tensorflow-线性回归与手写数字分类
线性回归 步骤 构造线性回归数据 定义输入层 设计神经网络中间层 定义神经网络输出层 计算二次代价函数,构建梯度下降 进行训练,获取预测值 画图展示 代码 import tensorflow as t ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
随机推荐
- action spring 注入错误,如果检查各项注入都没有错误时,考虑struts 是否配置了namespace(如果你有多个namespace="/")
[ERROR] 2015-01-04 09:42:35,180 (CommonsLogger.java:38) - Exception occurred during processing reque ...
- js把一个数组插入到另一个数组的指定位置
var arr1 = ['a', 'b', 'c']; var arr2 = ['1', '2', '3']; // 把arr2 变成一个适合splice的数组(包含splice前2个参数的数组) a ...
- 2018.08.30 NOIP模拟 wall(模拟)
[问题描述] 万里长城是中国强大的标志,长城在古代的用途主要用于快速传递军事消息和抵御 外敌,在长城上的烽火台即可以作为藏兵的堡垒有可以来点燃狼烟传递消息. 现在有一段 万里长城,一共有 N 个烽火台 ...
- 2018.08.09洛谷P3959 宝藏(随机化贪心)
传送门 回想起了自己赛场上乱搜的20分. 好吧现在也就是写了一个随机化贪心就水过去了,不得不说随机化贪心大法好. 代码: #include<bits/stdc++.h> using nam ...
- 实现字符串函数,strlen(),strcpy(),strcmp(),strcat()
实现字符串函数,strlen(),strcpy(),strcmp(),strcat() #include<stdio.h> #include<stdlib.h> int my_ ...
- modelsim读写TXT文件
//open the file Initial Begin step_file = $fopen("F:/Company/Src/txt/step.v","r" ...
- 城市边界线预测(根据灯光指数)(PUL)
1.EXEALL.m function EXEALL(FilePath, FileName)%执行所有流程% FilePath: 文件夹所在路径% FileName: 文件夹名称 FullPath = ...
- 利用Project Tango进行室内三维建模 精度评定
coming soon 在Android开发基础上开发Tango应用 Android+Tango
- struts2从浅至深(三)拦截器
一:拦截器概述 Struts2中的很多功能都是由拦截器完成的. 是AOP编程思想的一种应用形式. 二:拦截器执行时机 interceptor表示 ...
- HDU1459 非常可乐(BFS) 2016-07-24 15:00 165人阅读 评论(0) 收藏
非常可乐 Problem Description 大家一定觉的运动以后喝可乐是一件很惬意的事情,但是seeyou却不这么认为.因为每次当seeyou买了可乐以后,阿牛就要求和seeyou一起分享这一瓶 ...