使用fastai完成图像分类
1. 使用现有数据集进行分类
图像数据为Oxford-IIIT Pet Dataset(12类猫和25类狗,共37类),这里仅使用原始图片集images.tar.gz
数据准备
import numpy as np
from fastai.vision import *
from fastai.metrics import error_rate
path_img = 'data/pets/images'
bs = 64 #batch size
fnames = get_image_files(path_img) #get filenames(absolute path) from path_img
pat = re.compile(r'/([^/]+)_d+.jpg$') #get labels from filenames(e.g., 'american_bulldog' from 'data/pets/images/american_bulldog_20.jpg')
### ImageDataBunch
### 使用正则表达式pat从图像文件名fnames中提取标签,并和图像对应起来
### ds_tfms: 图像转换(翻转、旋转、裁剪、放大等),用于图像数据增强(data augmentation)
### size: 最终图像尺寸, bs: batch size, valid_pct: train/valid split
### normalize: 使用提供的均值和标准差(每个通道对应一个均值和标准差)对图像数据进行归一化
np.random.seed(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs, valid_pct=0.2).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7,6)) #grab a batch and display 3x3 images
模型搭建和训练
使用Resnet34进行迁移学习,首先通过lr_find确定最大学习率,再通过fit_one_cycle(1-Cycle style)进行训练
lr_find: 在前面几次的迭代中将学习率从一个很小的值逐渐增加,选择损失函数(train loss)处于下降趋势之中并且距离损失停止下降的拐点有一定距离的点做为模型的最大学习率max_lr
fit_one_cycle: 共分为两个阶段,在第一阶段学习率从max_lr/div_factor线性增长到max_lr,momentum线性地从moms[0]降到moms[1];第二阶段学习率以余弦形式从max_lr降为0,momentum也同样按余弦形式从moms[1]增长到moms[0]。第一阶段的迭代次数占总迭代次数的比例为pct_start
学习率和momentum: , , , 其中是要更新的参数,G为梯度, 为学习率, 为momentum
### Use Resnet34 to classify images
learn = create_cnn(data, models.resnet34, metrics=error_rate)
print(learn.model) #model summary
learn.lr_find()
learn.recorder.plot() #由左上图可以看出max_lr可选择函数fit_one_cycle的默认值0.003
learn.fit_one_cycle(4, max_lr=slice(0.003), div_factor=25.0, moms=(0.95, 0.85), pct_start=0.3) #4 epochs
learn.recorder.plot_lr(show_moms=True) #中上图(学习率)和右上图(momentum), x轴表示迭代次数
learn.save('stage-1') #save model
### Unfreeze all the model layers and keep training
learn.unfreeze()
learn.lr_find()
learn.recorder.plot() #左下图
### 由左下图可以看出max_lr可选择1e-6, 但是模型的不同层可以设置不同的学习率加速训练
### 模型的前面几层的学习率设置为max_lr, 后面几层的学习率可以适当增加(例如可以设置成比上一个fit_one_cycle的学习率小一个量级)
### slice(1e-6,1e-4)表示模型每层的学习率由1e-6逐渐增加过渡到1e-4
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4), div_factor=25.0, moms=(0.95, 0.85), pct_start=0.3) #2 epochs
learn.recorder.plot_lr(show_moms=True) #中下图(模型最后一层的学习率)和右下图(momentum)
可视化
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60) #confusion matrix
print(interp.most_confused(min_val=2)) #从大到小列出混淆矩阵中非对角线的最大的几个元素
2. 从谷歌图片下载数据并进行分类
获得图片链接
打开谷歌图片,输入想要下载的图像类别,页面上出现的图片即为可下载的图片
打开JavaScript Console(Windows/Linux:Ctrl+Shift+J, Mac:Cmd+Opt+J),运行下面的命令获取图片链接
大专栏 使用fastai完成图像分类 class="nx">urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('n')));
分别搜索teddy bears、 black bears、 grizzly bears, 将下载的保存链接的文件分别命名为urls_teddys.txt、 urls_black.txt、 urls_grizzly.txt
下载图片
import numpy as np
from fastai.vision import *
from fastai.metrics import error_rate
### 建立目录并下载图片
path = Path('data/bears')
folders = ['teddys', 'black', 'grizzly']
files = 'urls_teddys.txt', 'urls_black.txt', 'urls_grizzly.txt'
for i,folder in enumerate(folders):
dest = path/folder
dest.mkdir(parents=True, exist_ok=True)
download_images(files[i], dest, max_pics=200)
print(path.ls())
### 删除不能被打开的图片
for folder in folders:
verify_images(path/folder, delete=True, max_size=500)
训练模型
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, bs=64, num_workers=4).normalize(imagenet_stats)
print(data.classes)
learn = create_cnn(data, models.resnet34, metrics=error_rate)
learn.lr_find()
learn.recorder.plot() #左图
learn.fit_one_cycle(4)
learn.save('stage-1')
learn.unfreeze()
learn.lr_find()
learn.recorder.plot() #右图
learn.fit_one_cycle(2, max_lr=slice(3e-5,3e-4)) #若数据量较小,该步不一定有正效果
learn.save('stage-2')
learn.load('stage-1') #选择stage-1
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
根据训练好的模型去除错误图片
模型预测效果不好不一定是因为模型本身的问题,还可能是由于图片自身的问题(例如下载了错误的图片,图片标签有误),需要进行检查和处理
from fastai.widgets import *
### ds: 训练图片集, idxs: 具有最大损失的训练图片索引
ds, idxs = DatasetFormatter().from_toplosses(learn, n_imgs=200) #选出前200个具有最大损失的训练图片
ImageCleaner(ds, idxs, path) #手动处理,处理好的文件被存入path/cleaned.csv(该文件仅包含经过处理后的训练图片集,不包含验证图片)
可根据具体情况对处理之后的数据重新进行训练
保存模型并预测
learn.export() #将模型存入learn.path/export.pkl
learn = load_learner(path) #从path中读取模型
img = open_image(path/'black'/'00000021.jpg') #以训练集中的一个图片为例
pred_class,pred_idx,outputs = learn.predict(img) #预测图片
print(pred_class) #输出类别
print(outputs) #输出每个类的概率
使用fastai完成图像分类的更多相关文章
- Atitit 图像处理--图像分类 模式识别 肤色检测识别原理 与attilax的实践总结
Atitit 图像处理--图像分类 模式识别 肤色检测识别原理 与attilax的实践总结 1.1. 五中滤镜的分别效果..1 1.2. 基于肤色的图片分类1 1.3. 性能提升2 1.4. --co ...
- 【转】[caffe]深度学习之图像分类模型AlexNet解读
[caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097 本文章已收录于: ...
- 基于Pre-Train的CNN模型的图像分类实验
基于Pre-Train的CNN模型的图像分类实验 MatConvNet工具包提供了好几个在imageNet数据库上训练好的CNN模型,可以利用这个训练好的模型提取图像的特征.本文就利用其中的 “im ...
- [caffe]深度学习之图像分类模型VGG解读
一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...
- 如何在程序中调用Caffe做图像分类
Caffe是目前深度学习比较优秀好用的一个开源库,采样c++和CUDA实现,具有速度快,模型定义方便等优点.学习了几天过后,发现也有一个不方便的地方,就是在我的程序中调用Caffe做图像分类没有直接的 ...
- [caffe]深度学习之图像分类模型AlexNet解读
在imagenet上的图像分类challenge上Alex提出的alexnet网络结构模型赢得了2012届的冠军.要研究CNN类型DL网络模型在图像分类上的应用,就逃不开研究alexnet.这是CNN ...
- 【深度学习系列】用PaddlePaddle和Tensorflow进行图像分类
上个月发布了四篇文章,主要讲了深度学习中的"hello world"----mnist图像识别,以及卷积神经网络的原理详解,包括基本原理.自己手写CNN和paddlepaddle的 ...
- 【Keras】从两个实际任务掌握图像分类
我们一般用深度学习做图片分类的入门教材都是MNIST或者CIFAR-10,因为数据都是别人准备好的,有的甚至是一个函数就把所有数据都load进来了,所以跑起来都很简单,但是跑完了,好像自己还没掌握图片 ...
- OpenCV探索之路(二十八):Bag of Features(BoF)图像分类实践
在深度学习在图像识别任务上大放异彩之前,词袋模型Bag of Features一直是各类比赛的首选方法.首先我们先来回顾一下PASCAL VOC竞赛历年来的最好成绩来介绍物体分类算法的发展. 从上表我 ...
随机推荐
- os发展史
01. 操作系统的发展历史 1.1 Unix 1965 年之前的时候,电脑并不像现在一样普遍,它可不是一般人能碰的起的,除非是军事或者学院的研究机构,而且当时大型主机至多能提供30台终端(30个键盘. ...
- java使用forEach填充字典值
// 填充字典值 Vector vector = vectorMapper.selectByPrimaryKey(id); VectorModel vectorModel = new VectorMo ...
- 如何正确理解SQL关联子查询
一.基本逻辑 对于外部查询返回的每一行数据,内部查询都要执行一次.在关联子查询中是信息流是双向的.外部查询的每行数据传递一个值给子查询,然后子查询为每一行数据执行一次并返回它的记录.然后,外部查询根据 ...
- c#学习笔记02——接口
本身并不实现功能,但提供一种模板定义,为从它继承类或结构提供了一种定义的规范 有了接口,程序员可以把程序定义的更积极啊清晰和条理化 理解接口 接口支持多继承:抽象类不能实现多继承 接口只能定义抽象规则 ...
- 复杂json解析方式[GsonFormat]
针对开发人员来讲,善于用工具,事半功倍. 干货: 1.IntelliJ IDEA 通过GsonFormat插件将JSONObject格式的String 解析成实体 插件地址:https://plugi ...
- Tidb go mac 上开发环境搭建
1.安装golang 运行环境 2.安装lite ide 工具 3.安装dep 包管理工具 4.安装delve debuger 调试工具 我用的是mac hight sierra 10.13 版, 会 ...
- 11)PHP,单选框和复选框的post提交方式处理
就是一个表单中会有input的checkbox形式,那么怎么处理,就有了问题,一般采用二维数组来处理 代码展示: <!DOCTYPE html PUBLIC "-//W3C//DTD ...
- 在拖放文件的同时检测shift键的状态
老板要给原来文件拖放的功能加个扩展分类,于是想在文件拖放时判断shift键的状态来区分. 一般通过keydown和keyup来判断按下与否,但这都是需要控件事件触发,而在拖放的时候是没法触发key事件 ...
- 许家印67亿买下FF恒大是要雪中送炭吗?
从大环境来看,当下新能源汽车已经是备受投资者青睐的领域.据不完全统计,当下国内已经有300余家电动汽车企业.而蔚来.小鹏.威马等动辄都融资上百亿元,显现出火爆的发展趋势.甚至就连董明珠董大姐也有着自己 ...
- 42)PHP,mysqli函数功能总结
fetch----------------一个一个的取值,这个注意 fetch_array(),fetch_assoc(),fetch_object(),这三个方法的使用请看手册 请注意是FETCH, ...