OpenCV + sklearnSVM 实现手写数字分割和识别
这学期机器学习考核方式以大作业的形式进行考核,而且只能使用一些传统的机器学习算法。
综合再三,选择了自己比较熟悉的MNIST数据集以及OpenCV来完成手写数字的分割和识别作为大作业。
1. 数据集准备
MNIST数据集是一个手写数字的数据库,包含60000张训练图片和10000张测试图片,每张图片大小为28x28像素,每张图片都是一个
灰度图,像素取值范围在0-255之间。
这里使用pytorch的torchvision.datasets模块来读取MNIST数据集。
from torchvision import datasets
mnist_set = datasets.MNIST(root="./MNIST", train=True, download=True)
具体参数说明请自行搜索。注意若donwload=True
,则torchvision会通过内置链接自动下载数据集,
但是有时会失效。因此可以自己去网络上下载并解压后排列成指定文件树,如下
MNIST
├── MNSIT
│ ├── raw
│ │ ├── t10k-images-idx3-ubyte.gz
│ │ ├── t10k-labels-idx1-ubyte.gz
│ │ ├── train-images-idx3-ubyte.gz
│ │ └── train-labels-idx1-ubyte.gz
然后使用如下语句去读取数据集
img, target = minst_set[0]
其中每个img类型为PILimage,target类型为int,代表该图片对应的数字。
但是在喂给SVM训练时需要的是[batch_size, data]大小的numpy数组,因此需要做一些预处理
x_, y_ = list(zip(*([(np.array(img).reshape(28*28), target) for img, target in mnist_set])))
上面的语句实现了将MNIST数据集转换成numpy数组的形式,其中x_是每个成员为[1, 784]的numpy数组,y_为对应的数字所组成的列表。
2. SVM训练
支持向量机(support vector machine,SVM)是经典的机器学习算法,其通过选取两个n维支持向量(support vector)之间的n维超平面来对两类对象进行二分类。而专注于分类的SVM又称作Support Vector Classification,SVC。
求解SVM是一个很复杂的问题,但是万幸的是sklearn中有封装的很好的模块,可以很简单的直接使用
from sklearn.svm import SVC
svc = SVC(kernel='rbf', C=1)
svc.fit(x_, y_)
其中fit接口接受两个参数,第一个参数为训练数据[batch_size, data],第二个参数为训练标签[batch_size,1]。
SVC的构造函数如下
SVC(C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)
比较重要的参数有kernel,C,decison_function_shape等。
- kernel参数指定了核函数,常用的有linear,poly,rbf,sigmoid等。
- C为惩罚系数,C越大,对误分类的惩罚越大,模型越保守,C越小,对误分类的惩罚越小,模型越宽松,也就是较大的C在训练集上会有更高的正确率,较小的C会容许噪声的存在,泛化能力较强。
- decision_function_shape参数指定了决策函数的形状,ovr表示one-vs-rest,ovo表示one-vs-one,具体的意思可以网络查阅
4. 数字分割
数字分割是指将图像中的数字部分分割出来,然后一个一个喂给SVM进行分类
这里就是使用opencv对拍摄的图像进行轮廓提取后拟合外接矩形,借此来提取数字部分的ROI。
这里选择进行Canny边缘检测后去进行轮廓提取,然后拟合外接矩形,因为相较于直接二值化后去提取数字部分的ROI,
边缘检测对数字与纸张的边界更加敏感,即便在光照不均匀的情况下,也能较好的提取出数字的边缘。鲁棒性强。
5. 杂项与代码
这里还有一些杂项,比如保存模型,加载模型
使用pickle模块对训练好的模型对象进行序列化保存与加载,可以将训练好的模型保存到本地,以便后续使用。
最后贴出代码
代码
import os.path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets
from torchvision import transforms
from sklearn import svm
from sklearn import preprocessing
from sklearnex import patch_sklearn
import pickle
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import learning_curve
'''
@brief 加载MNIST数据集并转换格式成二值图
@param train: 是否为训练集
@param data_enhance: 是否进行数据增强
@return 二值图集和标签集
'''
def LoadMnistDataset(train=True, data_enhance=False):
mnist_set = datasets.MNIST(root="./MNIST", train=train, download=True)
x_, y_ = list(zip(*([(np.array(img), target) for img, target in mnist_set])))
sets_raw = []
sets_r20 = []
sets_invr20 = []
y = []
y_r20 = []
y_invr20 = []
sets = []
matrix_r20 = cv2.getRotationMatrix2D((14, 14), 25, 1.0)
matrix_invr20 = cv2.getRotationMatrix2D((14, 14), -25, 1.0)
select = 0
for idx in range(len(x_)):
# 对图像进行二值化以及数据增强
_, img = cv2.threshold(x_[idx], 255, 255, cv2.THRESH_OTSU)
sets_raw.append(np.array(img.data).reshape(784))
y.append(y_[idx])
if data_enhance:
if select % 2 == 0:
img_r20 = ~cv2.warpAffine(~img, matrix_r20, (28, 28), borderValue=(255, 255, 255))
sets_r20.append(np.array(img_r20.data).reshape(784))
y_r20.append(y_[idx])
else:
img_invr20 = ~cv2.warpAffine(~img, matrix_invr20, (28, 28), borderValue=(255, 255, 255))
sets_invr20.append(np.array(img_invr20.data).reshape(784))
y_invr20.append(y_[idx])
select += 1
# 数据增强
sets = sets_raw + sets_r20 + sets_invr20
sets = np.array(sets)
print(sets.shape)
if data_enhance:
y = y + y_r20 + y_invr20
return sets, y
'''
@brief 保存SVM模型
@param svc_model: SVM模型
@param file_path: 模型保存路径,默认为./SVC
@return None
'''
def SaveSvcModel(svc_model, file_path="./SVC"):
with open(file_path, 'wb') as fs:
pickle.dump(svc_model, fs)
'''
@brief 加载SVM模型
@param file_path: 模型保存路径,默认为./SVC
@return SVM模型
'''
def LoadSvcModel(file_path="./SVC"):
if not os.path.exists(file_path):
assert "Model Do Not Exist"
with open(file_path, 'rb') as fs:
svc_model = pickle.load(fs)
return svc_model
'''
@brief 训练SVM模型
@param c: SVM参数C
@param enhance: 是否进行数据增强
@return acc: 在测试集上的准确率
svc_model: SVM模型
'''
def TrainSvc(c, enhance):
# 读取数据集,训练集及测试集
images_train, targets_train = LoadMnistDataset(train=True, data_enhance=enhance)
images_test, targets_test = LoadMnistDataset(train=False, data_enhance=enhance)
# 训练
svc_model = svm.SVC(C=c,kernel='rbf', decision_function_shape='ovr')
svc_model.fit(images_train, targets_train)
# 在测试集上测试准确度
res = svc_model.predict(images_test)
correct = (res == targets_test).sum()
accuracy = correct / len(images_test)
print(f"测试集上的准确率为{accuracy * 100}%")
return svc_model
'''
@brief 预处理比较粗的字体
@param image: 输入图像
@:param show: 是否显示预处理后的图像
@:param thresh: 二值化阈值
@return 预处理后的图像数据
'''
def PreProcessFatFont(image, show=False):
# 白底黑字转黑底白字
pre_ = ~image
# 转单通道灰度
pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)
# 二值化
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
# resize后添加黑色边框,亲测可提高识别率
pre_ = cv2.resize(pre_, (112, 112))
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
back = np.zeros((300, 300), np.uint8)
back[29:141, 29:141] = pre_
pre_ = back
if show:
cv2.imshow("show", pre_)
cv2.waitKey(0)
# 做一次开运算(腐蚀 + 膨胀)
kernel = np.ones((2, 2), np.uint8)
pre_ = cv2.erode(pre_, kernel, iterations=1)
kernel = np.ones((3, 3), np.uint8)
pre_ = cv2.dilate(pre_, kernel, iterations=1)
# 第二次resize
pre_ = cv2.resize(pre_, (56, 56))
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
# 做一次开运算(腐蚀 + 膨胀)
kernel = np.ones((2, 2), np.uint8)
pre_ = cv2.erode(pre_, kernel, iterations=1)
kernel = np.ones((3, 3), np.uint8)
pre_ = cv2.dilate(pre_, kernel, iterations=1)
# resize成输入规格
pre_ = cv2.resize(pre_, (28, 28))
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
# 转换为SVM的输入格式
pre_ = np.array(pre_).flatten().reshape(1, -1)
return pre_
'''
@brief 预处理细的字体
@param image: 输入图像
@param show: 是否显示预处理后的图像
@param thresh: 二值化阈值
@return 预处理后的图像数据
'''
def PreProcessThinFont(image, show=False):
# 白底黑字转黑底白字
pre_ = ~image
# 转灰度图
pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)
# 增加黑色边框
pre_ = cv2.resize(pre_, (112, 112))
_, pre_ = cv2.threshold(pre_,thresh=0, maxval=255, type=cv2.THRESH_OTSU)
back = np.zeros((170, 170), dtype=np.uint8) # 这里不指明类型会导致后续矩阵强转为float64,无法使用大津法阈值
back[29:141, 29:141] = pre_
pre_ = back
if show:
cv2.imshow("show", pre_)
cv2.waitKey(0)
# 对细字体先膨胀一下
kernel = np.ones((3, 3), np.uint8)
pre_ = cv2.dilate(pre_, kernel, iterations=2)
# 第二次resize
pre_ = cv2.resize(pre_, (56, 56))
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
# 做一次开运算(腐蚀 + 膨胀)
kernel = np.ones((2, 2), np.uint8)
pre_ = cv2.erode(pre_, kernel, iterations=1)
kernel = np.ones((3, 3), np.uint8)
pre_ = cv2.dilate(pre_, kernel, iterations=1)
# resize成输入规格
pre_ = cv2.resize(pre_, (28, 28))
_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)
# 转换为SVM输入格式
pre_ = np.array(pre_).flatten().reshape(1, -1)
return pre_
'''
@brief 在空白背景上显示提取出的roi
@param res_list: roi列表
@return None
'''
def ShowRoi(res_list):
back = 255 * np.ones((1000, 1500, 3), dtype=np.uint8)
# 图片x轴偏移量
tlx = 0
for roi in res_list:
if tlx + roi.shape[1] > back.shape[1]:
break
# 每次在原图上加上一个roi
back[0:roi.shape[0], tlx:tlx + roi.shape[1], :] = roi
tlx += roi.shape[1]
cv2.imshow("show", back)
cv2.waitKey(0)
'''
@brief 寻找数字轮廓并提取roi
@param src: 输入图像
@param thin: 是否为细字体
@param thresh: 二值化阈值
@return roi列表
'''
def FindNumbers(src, thin=True):
# 拷贝
dst = src.copy()
paint = src.copy()
roi = src.copy()
dst = ~dst
# 预处理
paint = cv2.resize(paint, (448, 448))
dst = cv2.resize(dst, (448, 448))
# 记录缩放比例,后来看这一步好像没啥意义
fx = src.shape[1] / 448
fy = src.shape[0] / 448
# 转单通道
dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY)
# 边缘检测后二值化,直接二值化的话由于采光不同的原因灰度直方图峰与峰之间可能会差距过大,导致二值图的分割不准确
# 而边缘检测对像素突变更加敏感,因此采用Canny边缘检测后二值化
cv2.Canny(dst, 200, 200, dst)
# 对于平常笔写的字太细,膨胀一下
if thin:
kernel = np.ones((5, 5), np.uint8)
dst = cv2.dilate(dst, kernel, iterations=1)
# 寻找外围轮廓
contours, _ = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 提取roi
roi_list = []
rect_list = []
for contour in contours:
rect = cv2.boundingRect(contour)
if not ((rect[2] * rect[3] < 400 or rect[2] * rect[3] > 448 * 448 / 2.5) or (rect[3] < rect[2])):
cv2.rectangle(paint, rect, (255, 0, 0), 1)
x_min = rect[0] * fx
x_max = (rect[0] + rect[2]) * fx
y_min = rect[1] * fy
y_max = (rect[1] + rect[3]) * fy
roi_list.append(roi[int(y_min):int(y_max), int(x_min):int(x_max)].copy())
rect_list.append(rect)
return paint, roi_list, rect_list
'''
@brief 以txt形式显示数据
@param data: 数据集
@return None
'''
def ShowDataTxt(data):
print("----------------------------------------------------------")
for i in range(28):
for j in range(28):
print(0 if data[0][i * 28 + j] == 255 else 1, end='')
print('\n')
print("----------------------------------------------------------")
if __name__ == "__main__":
# 加载
patch_sklearn()
model_path = "./SVC_C1_enhance.pkl"
if os.path.exists(model_path):
print("Model Exist, Load Form Serialization")
model = LoadSvcModel(model_path)
else:
print("Model Do Not Exist, Train")
# 训练
model = TrainSvc(1, False)
# 保存
SaveSvcModel(model, model_path)
# 测试
paint, nums, rects = FindNumbers(cv2.imread("test_final.jpg"))
predict_nums = []
for img in nums:
data = PreProcessThinFont(img, show=False)
# ShowDataTxt(data)
predict_nums.append(model.predict(data)[0])
for i in range(len(predict_nums)):
cv2.putText(paint,str(predict_nums[i]), (rects[i][0], rects[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
cv2.imshow("show", paint)
cv2.waitKey(0)
给出几个识别后的效果:
OpenCV + sklearnSVM 实现手写数字分割和识别的更多相关文章
- 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)
@ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...
- OpenCV+TensorFlow图片手写数字识别(附源码)
初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...
- opencv实现KNN手写数字的识别
人工智能是当下很热门的话题,手写识别是一个典型的应用.为了进一步了解这个领域,我阅读了大量的论文,并借助opencv完成了对28x28的数字图片(预处理后的二值图像)的识别任务. 预处理一张图片: 首 ...
- 手写数字0-9的识别代码(SVM支持向量机)
帮一个贴吧的朋友改的一段代码,源代码来自<机器学习实战> 原代码的功能是识别0和9两个数字 经过改动之后可以识别0~9,并且将分类器的产生和测试部分分开来写,免得每次测试数据都要重新生成分 ...
- 学习OpenCV——SVM 手写数字检测
转自http://blog.csdn.net/firefight/article/details/6452188 是MNIST手写数字图片库:http://code.google.com/p/supp ...
- 基于opencv的手写数字识别(MFC,HOG,SVM)
参考了秋风细雨的文章:http://blog.csdn.net/candyforever/article/details/8564746 花了点时间编写出了程序,先看看效果吧. 识别效果大概都能正确. ...
- OpenCV手写数字字符识别(基于k近邻算法)
摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...
- 基于opencv的手写数字字符识别
摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- 在opencv3中实现机器学习算法之:利用最近邻算法(knn)实现手写数字分类
手写数字digits分类,这可是深度学习算法的入门练习.而且还有专门的手写数字MINIST库.opencv提供了一张手写数字图片给我们,先来看看 这是一张密密麻麻的手写数字图:图片大小为1000*20 ...
随机推荐
- WPF 界面打不开提示 System.ArithmeticException Overflow or underflow in the arithmetic operation 异常
本文告诉大家如何解决界面打不开,抛出 System.ArithmeticException: Overflow or underflow in the arithmetic operation 异常的 ...
- JS代码优化小技巧
下面介绍一种JS代码优化的一个小技巧,通过动态加载引入js外部文件来提高网页加载速度 [基本优化] 将所有需要的<script>标签都放在</body>之前,确保脚本执行之前完 ...
- Linux 根文件系统的移植(从入门到精通)
一.简介 提到操作系统的安装,还得从大学的时候说起,刚入学的时,朋友的系统本崩了,跑去电脑城换个系统花了40大洋,震惊了贫穷的我.好像发现了商机,果断开始了折腾自己的电脑,然后用朋友的电脑进行测试,由 ...
- 实验8 #第8章 Verilog有限状态机设计-1 #Verilog #Quartus #modelsim
8-1 流水灯控制器 1. 实验要求:采用有限状态机设计彩灯控制器,控制LED灯实现预想的演示花型. 2. 实验内容: (1)功能:设计彩灯控制器,要求控制18个LED灯实现如下的演示花型: 从两边往 ...
- 习题8 #第8章 Verilog有限状态机设计-4 #Verilog #Quartus #modelsim
4. 用状态机设计交通灯控制器,设计要求:A路和B路,每路都有红.黄.绿三种灯,持续时间为:红灯45s,黄灯5s,绿灯40秒. A路和B路灯的状态转换是: (1) A红,B绿(持续时间40s): (2 ...
- .Net 8.0 下的新RPC,IceRPC之使用Dev Containers进行 .NET QUIC 精简开发
作者引言 很高兴啊,我们来到了IceRPC之使用Dev Containers进行 .NET QUIC 精简开发,主要是一篇指引,如何使用开发容器做为开发环境,进行开发IceRPC,可适用于任务应用的开 ...
- vue3.0 yarn启动项目
linux 系统 在root账号下 yarn install yarn run serve 启动服务 ctrl+c //暂停服务 yarn build 打包服务 在公共目录里添加配置文件 优点:这样就 ...
- golang cron定时任务简单实现
目录 星号(*) 斜线(/) 逗号(,) 连字符 (-) 问好 (?) 常用cron举例 使用说明 golang 实现定时服务很简单,只需要简单几步代码便可以完成,不需要配置繁琐的服务器,直接在代码中 ...
- WEB服务与NGINX(11)-NGINX状态页
nginx状态页 nginx的状态页功能用于输出nginx的基本状态信息,基于ngx_http_stub_status_module模块实现. 默认情况下不生成此模块,应使用--with-http_s ...
- PVT:特征金字塔在Vision Transormer的首次应用,又快又好 | ICCV 2021
论文设计了用于密集预测任务的纯Transformer主干网络PVT,包含渐进收缩的特征金字塔结构和spatial-reduction attention层,能够在有限的计算资源和内存资源下获得高分辨率 ...