The Impact of Imbalanced Training Data for Convolutional Neural Networks

Paulina Hensman and David Masko

摘要

本论文从实验的角度调研了训练数据的不均衡性对采用CNN解决图像分类问题的性能影响。CIFAR-10数据集包含10个不同类别的60000个图像,用来构建不同类间分布的数据集。例如,一些训练集中包含一个类别的图像数目与其他类别的图像数目比例失衡。用这些训练集分别来训练一个CNN,度量其得到的网络的分类性能。实验结果表明:不均衡的训练数据对CNN的整体性能可能具有严重的负面影响,而均衡的训练数据能产生最好的性能。Oversampling技术在不均衡训练数据上可以将性能提升到均衡数据上的水平,所以它是一种对抗不均衡性的重要技术。

概况

在过去的几年里,由于在诸如机器视觉、语音识别及自然语言处理等几个领域获得重大突破,人工神经网络(Artificial Neural Networks)得到广泛的关注。没有任何先验与假设,这些网络采用统计的方法可以近似大量数据中潜在的函数与模式。DNN(Deep Neural Networks)以及CNN(Convolutional Neural Networks)两类特殊的神经网络是常用来解决复杂问题的现代方法。不利的一面是,为了学习到一个令人满意的神经网络,通常需要大量的数据。对于有监督的学习,还需要大量的标注数据。众所周知,标注数据通常是依赖于人工标注获得,因此获取困难。有一些标注好的图像数据是公开可用的,这些数据为研究与应用人员提供了标准资源,便于比较不同分类方法的,用以证明在该领域取得了一些进展。经验上讲,平衡数据集优于非平衡数据集,然而在真实的情况下,可用的数据集通常是不均衡的。如何处理不均衡数据是机器学习中一个很大的挑战。一些方法能够减轻不均衡数据带来的影响,但是并没有系统的研究结果表明DNN与CNN在标准数据集上如何受不均衡数据的影响。

本文重点研究由于训练数据的类别不均衡带来的CNN分类性能的损失。由此进一步探索:什么类型的分布对性能有损?Oversampling在提升性能方面起多大的作用?具体来讲主要包含以下四个问题:

(1)训练数据中均衡的类别分别对CNN的重要性有多大?

(2)CNN的性能如何受训练数据中不同类别分布的影响?

(3)通过调整训练数据的类别分布能否改善CNN的性能?

(4)有什么可行的方法来实现这种调整?

图像分类是判断给定的图像属于哪一类别的过程,直观来讲,就是图像包含了哪些物体。图像分类主要有两种形式:图像级别标注与对象级别标注。图像级别标注是一个二值变量,用来指示一个对象是否出现在图像上,例如,图像上是否有一只猫。对象级别的标注是具体到对象在图像中出现的位置。例如,螺丝刀中心位于(20,25),宽为50像素,高为30像素。本文关注图像级别的标注。

不均衡数据是指机器学习算法在训练的过程中所采用的数据在不同类别上的分布是不均衡的。由于采用均衡数据学习的算法性能远优于不均衡数据的,所以不均衡数据给分类问题带来了挑战。实际中可用的数据通常是不均衡的。然而,大多数的学习算法假设训练数据是均衡,也同样假设未标注的数据也是类间均衡的。若训练数据的分布于测试集并不相同,这类算法通常会降低性能。进一步来讲,多数算法的目标在于最小化整体的错误率,这会导致训练数据中的小众类由于训练数据少而性能不佳。当小众类非常重要时,这种影响是完全负面的。例如,罕见疾病的诊断。不均衡数据已经得到了广泛的关注,有许多有效的方法可以解决这个问题。

已有提升不均衡数据上的学习性能的方法大致分为三类: (1)sampling techniques;(2)Cost sensitive techniques;(3)One-class learning。采样技术改变原始的数据集,从而创建均衡数据集。简单的采样技术包括oversampling(从小众类中重复采样直至均衡),undersampling(移除over-represented类别的数据)与其他采样技术。然而有研究表明将oversampling与undersampling结合可能是应对极端不均衡数据的方式。

  • budget-sensitive progressive sampling algorithm

训练数据数目n

该采样策略依赖于几个假设:(1)与获取训练数据相比,学习算法的执行代价是可以忽略的,因为在该采样算法中学习算法需要运行多次。当训练数据获取代价高时,这一点是成立的。(2)假设每个类别的获取代价是相同的。这样的话预算数目n与训练实例数是一致的。这个假设大多时候是成立的,但也有例外。如,先前提及的电话数据,获取普通消费者和商业电话的代价是一样的,但是欺诈电话的识别代价是高昂的。

  • combination of cost-sensitive technique and undersampling
实验设置

数据集:选用CIFAR-10,包含10个不同的类别,数据集较小,仅包含60000左右的images(不选择ImageNet的原因),便于做批量的实验,但又不至于任务太简单(如MNIST)

数据集划分:5000 images per category for training and 1000 for testing

类别分布:选择11个不同的类别分布,分别考察其分类性能,每种分布其实都是具有代表性的,毕竟10个类别的分布均衡,是很难量化的一个指标,所以这里只是举出几个典型的例子来说明。在本文中,并没有给出class imbalance的一个明确的量化的定义。

网络结构:use caffe to create and train a CNN

参数设置:3 convolutional layers and 10 output nodes, trained with learning rate 0.001 for 8 epochs + learning rate 0.0001 for 2 epochs, momentum set to 0.9, weight decay to 0.004

测试数据:mean results of three runs

评价指标:the percentage of correct answers for each class,然后再做平均。

实验结果

(1)数据越均衡,分类性能越好

(2)oversampling可以给imbalance 数据带来性能的提升,数据越不均衡提升越明显。

阅读笔记 The Impact of Imbalanced Training Data for Convolutional Neural Networks [DegreeProject2015] 数据分析型的更多相关文章

  1. 论文笔记(Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration)

    这是CVPR 2019的一篇oral. 预备知识点:Geometric median 几何中位数 \begin{equation}\underset{y \in \mathbb{R}^{n}}{\ar ...

  2. 【论文阅读】Learning Dual Convolutional Neural Networks for Low-Level Vision

    论文阅读([CVPR2018]Jinshan Pan - Learning Dual Convolutional Neural Networks for Low-Level Vision) 本文针对低 ...

  3. [CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks论文笔记

    p.p1 { margin: 0.0px 0.0px 0.0px 0.0px; font: 13.0px "Helvetica Neue"; color: #323333 } p. ...

  4. 论文笔记之:Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking

    Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking  arXiv Paper ...

  5. 论文笔记之:Learning Multi-Domain Convolutional Neural Networks for Visual Tracking

    Learning Multi-Domain Convolutional Neural Networks for Visual Tracking CVPR 2016 本文提出了一种新的CNN 框架来处理 ...

  6. [论文阅读] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (MobileNet)

    论文地址:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 本文提出的模型叫Mobi ...

  7. 深度学习笔记 (一) 卷积神经网络基础 (Foundation of Convolutional Neural Networks)

    一.卷积 卷积神经网络(Convolutional Neural Networks)是一种在空间上共享参数的神经网络.使用数层卷积,而不是数层的矩阵相乘.在图像的处理过程中,每一张图片都可以看成一张“ ...

  8. Bag of Tricks for Image Classification with Convolutional Neural Networks笔记

    以下内容摘自<Bag of Tricks for Image Classification with Convolutional Neural Networks>. 1 高效训练 1.1 ...

  9. 论文笔记之《Event Extraction via Dynamic Multi-Pooling Convolutional Neural Network》

    1. 文章内容概述 本人精读了事件抽取领域的经典论文<Event Extraction via Dynamic Multi-Pooling Convolutional Neural Networ ...

随机推荐

  1. TCP/UDP的接收包方式

    UDP udp不是流式的,每次接收一个包,长度不超过(65535-28,总包长65535字节,包头28字节).所以UDP方式下不需要填写任何参数直接调用 $client->recv() 即可.注 ...

  2. ios xmpp开发应用后台模式接收聊天信息处理方案

    ios xmpp开发应用后台模式接收聊天信息 最近在使用xmppframwork来实现一个聊天应用,碰到了一个问题,应用进入后台以后,就接收不到消息了: 怎么样才能使应用被切到后台时,应用中的网络连接 ...

  3. ASP.NET中进行消息处理(MSMQ) 三(转)

    在本文的前两篇文章里对MSMQ的相关知识点进行了介绍,很多阅读过这前两篇文章的朋友都曾问到过这样一些问题:  1.如何把MSMQ应用到实际的项目中去呢?  2.可不可以介绍一个实际的应用实例?  3. ...

  4. 算法库:boost安装配置

    前提是电脑上已经装有VS. 1. 下载boost_1_60_0.zip并解压到所需位置 2. 双击bootstrap.bat生成b2.exe(新版)和bjam.exe(老版) 3. 双击b2.exe或 ...

  5. Quartz.NET管理周期性任务

    Quartz.NET是一个开源的作业调度框架,非常适合在平时的工作中,定时轮询数据库同步,定时邮件通知,定时处理数据等. Quartz.NET允许开发人员根据时间间隔(或天)来调度作业.它实现了作业和 ...

  6. No.5__C#

    One month 今天是个有纪念意义的日子,2015-4-23.今天是实习的第一个月,算是成就达成吧.虽然,除去了周末六日和清明什么的,只剩下20多天了,但是,还是好开心 啊,毕竟是第一次参加工作, ...

  7. strong reference cycle in block

    However, because the reference is weak, the object that self points to could be deallocated while th ...

  8. JS适配问题。

    动画requestAnimFrame + cancelAnimationFrame window.requestAnimFrame = (function(){ return window.reque ...

  9. CSS基础篇

    写的不错,收藏 http://www.cnblogs.com/suoning/p/5625582.html

  10. nodejs生成UID(唯一标识符)——node-uuid模块

    unique identifier 惟一标识符        -->> uid 在项目开发中我们常需要给某些数据定义一个唯一标识符,便于寻找,关联. node-uuid模块很好的提供了这个 ...