前言

本文与前文对手写数字识别分类基本类似的,同样图像作为输入,类别作为输出。这里不同的是,不仅仅是使用简单的卷积神经网络加上全连接层的模型。卷积神经网络大火以来,发展出来许多经典的卷积神经网络模型,包括VGG、ResNet、AlexNet等等。下面将针对CIFAR-10数据集,对图像进行分类。

1、CIFAR-10数据集、Reader创建

CIFAR-10数据集分为5个batch的训练集和1个batch的测试集,每个batch包含10,000张图片。每张图像尺寸为32*32的RGB图像,且包含有标签。一共有10个标签:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck十个类别。

我在CIFAR-10网站中下载的是[CIFAR-10 python version](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)。数据集完成后,解压得到上述六个文件。上述六个文件都是字典文件,使用cPickle模块即可读入。字典中‘data’需要重新定义维度为1000*32*32*3,维度分别代表[N H W C],即10,000张32*32尺寸的三通道(RGB)图像,再经过转换成为paddlepaddle读取的[N C H W ]维度形式;而字典‘labels’为10000个标签。如此一来,可以建立读取CIFAR-10的reader(与官方例程不同),如下:

def reader_creator(ROOT,istrain=True,cycle=False):
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename,'rb') as f:
datadict = Pickle.load(f)
X = datadict['data']
Y = datadict['labels']
""" (N C H W) transpose to (N H W C) """
X = X.reshape(10000,3,32,32).transpose(0,2,3,1).astype('float')
Y = np.array(Y)
return X,Y
def reader():
while True:
if istrain:
for b in range(1,6):
f = os.path.join(ROOT,'data_batch_%d'%(b))
X,Y = load_CIFAR_batch(f)
length = X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
else:
f = os.path.join(ROOT,'test_batch')
X,Y = load_CIFAR_batch(f)
length = X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
return reader

2、VGG网络

VGG网络采用“减小卷积核大小,增加卷积核数量”的思想改造而成,这里直接采用paddlepaddle例程中的VGG网络了,值得提醒的是paddlepaddle中直接有函数img_conv_group提供卷积、池化、dropout一组操作,所以根据VGG的模型,前面卷积层可以划分为5组,然后再经过3层的全连接层得到结果。

PaddlePaddle例程中根据上图D网络,加入dorpout:

def vgg_bn_drop(input):
def conv_block(ipt, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input=ipt,
#一组的卷积层的卷积核总数,组成list[num_filter num_filter ...]
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
#每组卷积层各层的droput概率
conv_batchnorm_drop_rate=dropouts,
pool_size=2,
pool_stride=2,
pool_type='max') conv1 = conv_block(input, 64, 2, [0.3, 0]) #[0.3 0]即为第一组两层的dorpout概率,下同
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = fluid.layers.fc(input=drop, size=512, act=None) bn = fluid.layers.batch_norm(input=fc1, act='relu') drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2 = fluid.layers.fc(input=drop2, size=512, act=None) predict = fluid.layers.fc(input=fc2, size=10, act='softmax')
return predict

3、训练

训练程序与上一节例程一样,同样是选取交叉熵作为损失函数,不多累赘讲述。

def train_network():
predict = inference_network()
label = fluid.layers.data(name='label',shape=[1],dtype='int64')
cost = fluid.layers.cross_entropy(input=predict,label=label)
avg_cost = fluid.layers.mean(cost)
accuracy = fluid.layers.accuracy(input=predict,label=label)
return [avg_cost,accuracy] def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001) def train(data_path,save_path):
BATCH_SIZE = 128
EPOCH_NUM = 2
train_reader = paddle.batch(
paddle.reader.shuffle(reader_creator(data_path),buf_size=50000),
batch_size = BATCH_SIZE)
test_reader = paddle.batch(
reader_creator(data_path,False),
batch_size=BATCH_SIZE)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
print("\nPass %d, Epoch %d, Cost %f, Acc %f" %
(event.step, event.epoch, event.metrics[0],
event.metrics[1]))
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, fluid.EndEpochEvent):
avg_cost, accuracy = trainer.test(
reader=test_reader, feed_order=['image', 'label'])
print('\nTest with Pass {0}, Loss {1:2.2}, Acc {2:2.2}'.format(
event.epoch, avg_cost, accuracy))
if save_path is not None:
trainer.save_params(save_path)
place = fluid.CUDAPlace(0)
trainer = fluid.Trainer(
train_func=train_network, optimizer_func=optimizer_program, place=place)
trainer.train(
reader=train_reader,
num_epochs=EPOCH_NUM,
event_handler=event_handler,
feed_order=['image', 'label'])

4、测试接口

测试接口也类似,需要特别注意的是图像维度要改为[N C H W]的顺序!

def infer(params_dir):
place = fluid.CUDAPlace(0)
inferencer = fluid.Inferencer(
infer_func=inference_network, param_path=params_dir, place=place)
# Prepare testing data.
from PIL import Image
import numpy as np
import os def load_image(file):
im = Image.open(file)
im = im.resize((32, 32), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
"""transpose [H W C] to [C H W]"""
im = im.transpose((2, 0, 1))
im = im / 255.0 # Add one dimension, [N C H W] N=1
im = np.expand_dims(im, axis=0)
return im
cur_dir = os.path.dirname(os.path.realpath(__file__))
img = load_image(cur_dir + '/dog.png')
# inference
results = inferencer.infer({'image': img})
print(results)
lab = np.argsort(results) # probs and lab are the results of one batch data
print("infer results: ", cifar_classes[lab[0][0][-1]])

5、运行结果

由于笔者没有GPU服务器,所以只迭代了50次,已经用了8个多小时,但是准确率只有15.6%,测试集方面准确率有17%,效果不理想,用于验证的结果也是错的!

Pass , Epoch , Cost 2.261115, Acc 0.156250
.........................................................................................
Test with Pass , Loss 2.2, Acc 0.17 Classify the cifar10 images...
[array([[0.05997971, 0.13485196, 0.096842 , 0.09973737, 0.11053724,
0.08180068, 0.13847008, 0.08627985, 0.06851784, 0.12298328]],
dtype=float32)]
infer results: frog

结语

网络比较深,且数据集比较大,训练时间比较长,普通笔记本上面的GT840M聊以胜无吧。

本文代码:02_cifar

参考:book/03.image_classification/

【PaddlePaddle系列】CIFAR-10图像分类的更多相关文章

  1. 【深度学习系列】用PaddlePaddle和Tensorflow进行图像分类

    上个月发布了四篇文章,主要讲了深度学习中的"hello world"----mnist图像识别,以及卷积神经网络的原理详解,包括基本原理.自己手写CNN和paddlepaddle的 ...

  2. ABP(现代ASP.NET样板开发框架)系列之10、ABP领域层——实体

    点这里进入ABP系列文章总目录 基于DDD的现代ASP.NET开发框架--ABP系列之10.ABP领域层——实体 ABP是“ASP.NET Boilerplate Project (ASP.NET样板 ...

  3. JVM基础系列第10讲:垃圾回收的几种类型

    我们经常会听到许多垃圾回收的术语,例如:Minor GC.Major GC.Young GC.Old GC.Full GC.Stop-The-World 等.但这些 GC 术语到底指的是什么,它们之间 ...

  4. Mysql高手系列 - 第10篇:常用的几十个函数详解,收藏慢慢看

    这是Mysql系列第10篇. 环境:mysql5.7.25,cmd命令中进行演示. MySQL 数值型函数 函数名称 作 用 abs 求绝对值 sqrt 求二次方根 mod 求余数 ceil 和 ce ...

  5. java高并发系列 - 第10天:线程安全和synchronized关键字

    这是并发系列第10篇文章. 什么是线程安全? 当多个线程去访问同一个类(对象或方法)的时候,该类都能表现出正常的行为(与自己预想的结果一致),那我们就可以所这个类是线程安全的. 看一段代码: pack ...

  6. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  7. ShoneSharp语言(S#)的设计和使用介绍系列(10)— 富家子弟“语句“不炫富

    ShoneSharp语言(S#)的设计和使用介绍 系列(10)— 富家子弟“语句“不炫富 作者:Shone 声明:原创文章欢迎转载,但请注明出处,https://www.cnblogs.com/Sho ...

  8. RabbitMQ 入门系列:10、扩展内容:延时队列:延时队列插件及其有限的适用场景(系列大结局)。

    系列目录 RabbitMQ 入门系列:1.MQ的应用场景的选择与RabbitMQ安装. RabbitMQ 入门系列:2.基础含义:链接.通道.队列.交换机. RabbitMQ 入门系列:3.基础含义: ...

  9. 深度学习与计算机视觉系列(2)_图像分类与KNN

    作者: 寒小阳 &&龙心尘 时间:2015年11月. 出处: http://blog.csdn.net/han_xiaoyang/article/details/49949535 ht ...

随机推荐

  1. 将IP地址转化为整数

    $ip = 'IP地址';echo $intip = sprintf('%u',ip2long($ip)); //转换为无符号整型echo long2ip($intip);//将整型转换为ip

  2. Creating a Simple Web Service and Client with JAX-WS

    Creating a Simple Web Service and Client with JAX-WS 发布服务 package cn.zno.service.impl; import javax. ...

  3. (转) MVC 中 @help 用法

    ASP.NET MVC 3支持一项名为“Razor”的新视图引擎选项(除了继续支持/加强现有的.aspx视图引擎外).当编写一个视图模板时,Razor将所需的字符和击键数减少到最小,并保证一个快速.通 ...

  4. C++中的数组问题

    C++中的数组问题 1. 数组赋值与初始化 (1)直接初始化: ]={,,}: (2)遍历访问初始化: ;i< ;i++) //直接读入,或者用别的数组,以及别的(i+1)等. (3)内存操作函 ...

  5. 菜鸟——springboot+mybatis+maven

    网上找了很多资料,学习如何搭建springboot,由于刚刚接触springboot,不是很熟练,通过参考网上别人搭建的例子,自己也搭建了一个简单的springboot+mybaits+maven 网 ...

  6. NLTK之WordNet 接口

    WordNet是面向语义的英语词典,类似于传统字典.它是NLTK语料库的一部分,可以被这样调用: 更简洁的写法: 1.单词 查看一个单词的同义词集用synsets(); 它有一个参数pos,可以指定查 ...

  7. mdadm详细使用手册

    1. 文档信息 当前版本 1.2 创建人 朱荣泽 创建时间 2011.01.07 修改历史 版本号 时间 内容 1.0 2011.01.07 创建<mdadm详细使用手册>1.0文档 1. ...

  8. SQL SERVER存储过程中使用事务与捕获异常

    https://www.douban.com/note/559596669/ 格式类似于 CREATE PROCEDURE YourProcedure ASBEGIN    SET NOCOUNT O ...

  9. 曲演杂坛--使用TRY CATCH应该注意的一个小细节

    群里一个朋友遇到一个TRY CATCH的小问题,测试后发现是自己从来没有考虑的情况,写篇blog加深下印象 --============================================ ...

  10. 浏览器报ScriptResource.axd异常

    新拷贝的一份管理后台代码,部署在另一台服务器上时,查看浏览器调控制台,发现有几个红色报错.这些错误导致网站的部分功能无法使用. 主要错误有: 1.“Sys”未定义 2.asp.net ajax 客户端 ...