【目标检测】用Fast R-CNN训练自己的数据集超详细全过程
目录:
一、环境准备
二、训练步骤
三、测试过程
四、计算mAP
寒假在家下载了Fast R-CNN的源码进行学习,于是使用自己的数据集对这个算法进行实验,下面介绍训练的全过程。
一、环境准备
我这里的环境是win10系统,pycharm + python3.7
二、训练过程
1、下载Fast R-CNN源码
https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3
2、安装扩展包
下载的源码中有一个 requirements.txt文件,列出了需要安装的扩展包名字。可以在cmd中直接运行以下代码:
pip install -r requirements.txt
或者使用pip命令一个一个安装,所需要的扩展包有:cython、opencv-python、easydict、Pillow、matplotlib、scipy。
如果使用conda管理,按conda的下载方式也可以。
3、下载并添加预训练模型
源码中预训练模型使用的是VGG16,VGG16模型可点开下方链接直接下载:
http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
下载的模型名字应该是vgg_16.ckpt,重命名为vgg16.ckpt 后,把模型保存在data\imagenet_weights\文件夹下。
也可以使用其他的模型替代VGG16,其他模型在下方链接中下载:
https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
4、修改训练参数
打开源码的lib\config文件夹下的config.py文件,修改其中一些重要参数,如:
(1)network参数
该参数定义了预训练模型网络,源码中默认使用了vgg16模型,我们使用vgg16就不需修改,如果在上一步中使用其他模型就要修改。
(2)learning_rate
这个参数是学习率,如果设定太大就可能产生振荡,如果设定太小就会使收敛速度很慢。所以我们可以先默认为源码的0.001进行实验,后期再取0.01或0.0001等多次实验,找到运行后的相对最优值。
(3)batch_size
该参数表示梯度下降时数据批量大小,一般可以取16、32、64、128、256等。我个人的理解是,batch_size设定越大,训练时梯度下降的速率更快,也具有更高的方向准确度,但更加消耗内存;batch_size设定越小,虽然节省内存,但训练的速率比较慢,收敛效果也可能不是很好。所以在内存允许的情况下,尽量设定大一些。
(4)max_iters
max_iters参数表示训练最大迭代的步数。源码中是40000,我实验了4000和40000的步数,发现后来的测试结果中mAP值相差不大,以后会再继续研究。这个参数可以先按照源码的40000进行(要跑好几天。。。)
(5)snapshot_iterations
这个参数表示间隔多少迭代次数生成一次结果模型。
(6)roi_bg_threshold_low 和 roi_bg_threshold_high
这个参数表示在背景中被设定为ROI(感兴趣区域,region of interest)的阈值。如果后面出现Exception: image invalid, skipping 这样的报错,将roi_bg_threshold_low参数修改为0.0会解决问题。
5、替换数据集
源码中的VOCDevkit2007文件夹存放的是数据集,我们将自己的数据集按照文件夹结构替换存放在VOCDevkit2007中。Annotations存放的是标签的XML文件,JPEGImages存放的是自己的数据集所有图片,ImageSets\Main文件夹下保存的是test.txt、train.txt、trainval.txt、validation.txt,分别是测试集、训练集、训练验证集、验证集的标签文件名号。可以按照下图的结构制作自己的数据集~
考虑到源码中没有数据集划分程序,这里把划分代码贴出来,替换成自己的各个文件路径后直接运行就可以自动生成所需的txt文件啦。
# 数据集划分集类
import os
from sklearn.model_selection import train_test_split image_path = r'F:/111/data/VOCDevkit2007/VOC2007/JPEGImages'
image_list = os.listdir(image_path)
names = [] for i in image_list:
names.append(i.split('.')[0]) # 获取图片名
trainval,test = train_test_split(names,test_size=0.5,shuffle=446) # shuffle()中是图片总数目
validation,train = train_test_split(trainval,test_size=0.5,shuffle=446) with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/trainval.txt','w') as f:
for i in trainval:
f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/test.txt','w') as f:
for i in test:
f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/validation.txt','w') as f:
for i in validation:
f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/train.txt','w') as f:
for i in train:
f.write(i+'\n') print('完成!')
6、生成所需文件
在cmd中进入 ./data/coco/PythonAPI文件夹路径,分别运行下面两条命令:
python setup.py build_ext --inplace
python setup.py build_ext install
之后,在cmd中进入 ./lib/utils文件夹路径,运行下面一条命令:
python setup.py build_ext --inplace
这样,就生成训练需要的文件啦。
7、修改目标类别
打开lib/datasets目录中的pascal_voc.py文件,第34行self._classes表示目标检测的类别,将其修改为自己数据集的类别。注意不能修改 “_background_”,它表示图片的背景。
8、删除缓存文件
打开源码中data/cache目录,删掉上一次训练生成的.pkl缓存文件。打开default/voc_2007_trainval/default目录,删掉上次训练生成的模型。
注意以后每次训练都要删掉上述两个文件夹中的缓存文件和模型,不删会报错的哦。
9、运行train.py文件
做好上面所有步骤之后,就可以运行train.py文件进行训练啦。每次生成的模型都会保存在default/voc_2007_trainval/default目录下。
三、测试过程
1、添加训练模型
新建Faster-RCNN-TensorFlow-Python3-master/output/vgg16/voc_2007_trainval/default目录。把训练生成的模型(default/voc_2007_trainval/default目录下的四个文件)复制到新建目录下,并重命名为如下图:
2、修改demo.py文件
(1)修改目标类别
修改demo.py文件中line32,CLASSES中的类别要修改为之前步骤中相同的类别。注意 “_background_”不要修改。
(2)修改网络模型
找到demo.py文件中line35、line36,将其修改为如下图所示:
(3)修改预训练模型
找到demo.py文件中line104,将其修改为'vgg16',如下图:
(4)修改测试图片
找到demo.py文件中的line148,改为自己测试用的几张图片名称。注意和data/demo目录下存放的测试图片名字一致。
3、运行demo.py文件
做好上述修改后,就可以运行demo.py文件啦,能够对新的测试图片进行目标检测。
四、计算mAP
mAP(mean Average Precision), 即各类别AP的平均值,反映出一个目标检测模型性能的总体精确度。
1、修改pascal_voc.py文件
打开pascal_voc.py文件,找到line189,将"filename"内容修改为下图:
2、修改demo.py文件
打开demo.py文件,找到line31,添加两个模块:
# 添加这两个import
from lib.utils.test import test_net
from lib.datasets.factory import get_imdb
添加后如图所示:
然后,找到最后一行plt.show(),在它上面添加两行代码:
# 添加这两行代码
imdb = get_imdb("voc_2007_trainval")
test_net(sess, net, imdb, 'default')
添加后如图所示:
3、运行demo.py文件
新建data/VOCDevkit2007/results/VOC2007/Main目录,然后运行demo.py文件,等待运行结束就能看到mAP指标的计算结果啦!贴出我自己模型的计算结果吧!
这次内容就分享到这里了,希望与各位老师和小伙伴们交流学习~
【目标检测】用Fast R-CNN训练自己的数据集超详细全过程的更多相关文章
- 【目标检测实战】目标检测实战之一--手把手教你LMDB格式数据集制作!
文章目录 1 目标检测简介 2 lmdb数据制作 2.1 VOC数据制作 2.2 lmdb文件生成 lmdb格式的数据是在使用caffe进行目标检测或分类时,使用的一种数据格式.这里我主要以目标检测为 ...
- 目标检测(三) Fast R-CNN
引言 之前学习了 R-CNN 和 SPPNet,这里做一下回顾和补充. 问题 R-CNN 需要对输入进行resize变换,在对大量 ROI 进行特征提取时,需要进行卷积计算,而且由于 ROI 存在重复 ...
- Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)
目录 1. 准备数据集 1.1 MNIST数据集获取: 1.2 程序部分 2. 设计网络结构 2.1 网络设计 2.2 程序部分 3. 迭代训练 4. 测试集预测部分 5. 全部代码 1. 准备数据集 ...
- R及R Studio下载安装教程(超详细)
R 语言是为数学研究工作者设计的一种数学编程语言,主要用于统计分析.绘图.数据挖掘. 如果你是一个计算机程序的初学者并且急切地想了解计算机的通用编程,R 语言不是一个很理想的选择,可以选择 Pytho ...
- 论文笔记:目标检测算法(R-CNN,Fast R-CNN,Faster R-CNN,FPN,YOLOv1-v3)
R-CNN(Region-based CNN) motivation:之前的视觉任务大多数考虑使用SIFT和HOG特征,而近年来CNN和ImageNet的出现使得图像分类问题取得重大突破,那么这方面的 ...
- 第三十节,目标检测算法之Fast R-CNN算法详解
Girshick, Ross. “Fast r-cnn.” Proceedings of the IEEE International Conference on Computer Vision. 2 ...
- 目标检测(三)Fast R-CNN
作者:Ross Girshick 该论文提出的目标检测算法Fast Region-based Convolutional Network(Fast R-CNN)能够single-stage训练,并且可 ...
- 第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)
前面已经介绍了几种经典的目标检测算法,光学习理论不实践的效果并不大,这里我们使用谷歌的开源框架来实现目标检测.至于为什么不去自己实现呢?主要是因为自己实现比较麻烦,而且调参比较麻烦,我们直接利用别人的 ...
- 目标检测算法(1)目标检测中的问题描述和R-CNN算法
目标检测(object detection)是计算机视觉中非常具有挑战性的一项工作,一方面它是其他很多后续视觉任务的基础,另一方面目标检测不仅需要预测区域,还要进行分类,因此问题更加复杂.最近的5年使 ...
随机推荐
- 数据分析中常用的Excel函数
数据分析中excel是一个常见且基础的数据分析工具,要想掌握好它,学会使用其中的常用函数是一个绕不过去的坎.从网上搜集的资料来说,基本上确定了数据分析中Excel的常用函数有以下这六类 数学函数:SU ...
- hbase伪分布式环境的搭建
一,实验环境: 1, ubuntu server 16.04 2, jdk,1.8 3, hadoop 2.7.4 伪分布式环境或者集群模式 4, hbase-1.2.6.tar.gz 二,环境的搭建 ...
- JDK-7新特性,更优雅的关闭流-java try-with-resource语句使用
前言 公司最近代码质量整改,需要对大方法进行调整,降低到50行以下,对方法的深度进行降低,然后有些文件涉及到流操作,很多try/catch/finally语句,导致行数超出规范值,使用这个语法可以很好 ...
- redis-避免生产环境使用keys命令
redis作为内存数据库, 有着很高的性能, Redis能读的速度是110000次/s, 写的速度是81000次/s; 除了进行持久化操作时, redis采用的是单线程架构, 所以如果我们在开发中不恰 ...
- 设计模式(十五)——命令模式(Spring框架的JdbcTemplate源码分析)
1 智能生活项目需求 看一个具体的需求 1) 我们买了一套智能家电,有照明灯.风扇.冰箱.洗衣机,我们只要在手机上安装 app 就可以控制对这些家电工作. 2) 这些智能家电来自不同的厂家,我们不想针 ...
- C - Door Man(欧拉回路_格式控制)
现在你是一个豪宅的管家,因为你有个粗心的主人,所以需要你来帮忙管理,输入会告诉你现在一共有多少个房间,然后会告诉你从哪个房间出发,你的任务就是从出发的房间通过各个房间之间的通道,来把所有的门都关上,然 ...
- Codeforces Round #646 (Div. 2) A. Odd Selection(数学)
题目链接:https://codeforces.com/contest/1363/problem/A 题意 判断是否能从 $n$ 个数中选 $x$ 个数加起来和为奇数. 题解 首先 $n$ 个数中至少 ...
- OpenStack Train版-15.创建并挂载存储卷
1.创建并挂载存储卷 创建一个1GB的卷 source ~/demo-openrc openstack volume create --size 1 volume1 很短的时间后,卷状态应该从crea ...
- python之字符串strip、rstrip、lstrip的方法
1.描述 strip():用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列 rstrip():用于移除字符串右边指定的字符(默认为空格或换行符)或字符序列 lstrip():用于移除字符串 ...
- Vmware 15.5 ubuntu 12.04.5-desktop-i386.iso insmod后死机
就是makefile没有问题,在其他同学的相同环境下也没有问题,但是在我的虚拟机里就会死机,复制了其他同学的虚拟机过来也会死机,所以猜想是VMware的问题. 于是下载了Virtual box,然后安 ...