写在前面

​ 文本分类是nlp中一个非常重要的任务,也是非常适合入坑nlp的第一个完整项目。虽然文本分类看似简单,但里面的门道好多好多,博主水平有限,只能将平时用到的方法和trick在此做个记录和分享,希望各位看官都能有所收获。并且尽可能提供给出简洁,清晰的代码实现。

​ 本文主要讨论文本分类中处理样本不均衡和提升模型鲁棒性的trick,也是最近面试总结的一部分(面试好痛苦/(ㄒoㄒ)/~~,但还是得淦)。文章内容是根据平时阅读论文,知乎,公众号和实践得到的,如有表述不够清楚、详尽的地方可参考文末的原作者链接。

缓解样本不均衡

假如我们要实现一个新闻正负面判断的文本二分类器,负面新闻的样本比例较少,可能2W条新闻有100条甚至更少的样本属于正例。这种现象就是样本不均衡,因为样本会呈现一个长尾分布,头部的标签包含了大量的样本,而尾部的标签拥有很少的样本,就像下面这张图片中表现的那样出现一个长长的尾巴,所以这种现场也称为长尾现象。

样本不均衡会带来很多问题。模型训练的本质是最小化损失函数,当某个类别的样本数量非常庞大,损失函数的值大部分被样本数量较大的类别所影响,导致的结果就是模型分类会倾向于样本量较大的类别。咱们拿上面文本分类的例子来说明,现在有2W条用户搜索的样本,其中100条是负面新闻,即正样本,那么当模型全部将样本预测为负例,也能得到99.5%的准确率,但这个模型跟盲猜也没区别,没什么用,我们的目的是找到让模型能够正确的区分正例和负例。

模型层面解决样本不均衡

加入Focal Loss学习难学样本,具体原理可以参考苏神的文章

Focal Loss pytorch代码实现

class FocalLoss(nn.Module):
"""Multi-class Focal loss implementation"""
def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
log_pt = torch.log_softmax(input, dim=1)
pt = torch.exp(log_pt)
log_pt = (1 - pt) ** self.gamma * log_pt
loss = torch.nn.functional.nll_loss(log_pt, target, self.weight, reduction=self.reduction, ignore_index=self.ignore_index)
return loss

代码链接:

py版本:https://github.com/PouringRain/blog_code/blob/main/nlp/focal_loss.py

喜欢的话,给萌新的github仓库一颗小星星哦……^ _^

数据层面解决样本不均衡

现在我们遇到样本不均衡的问题,假如我们的正样本只有100条,而负样本可能有1W条。如果不采取任何策略,那么我们就是使用这1.01W条样本去训练模型。从数据层面解决样本不均衡的问题核心是通过人为控制正负样本的比例,分成欠采样和过采样两种。

3.1 欠采样

欠采样的基本做法是这样的,现在我们的正负样本比例为1:100。如果我们想让正负样本比例不超过1:10,那么模型训练的时候数量比较少的正样本也就是100条全部使用,而负样本随机挑选1000条,这样通过人为的方式我们把样本的正负比例强行控制在了1:10。这种方式存在一个问题,为了强行控制样本比例我们生生的舍去了那9000条负样本,这对于模型来说是莫大的损失。

相比于简单的对负样本随机采样的欠采样方法,实际工作中我们会使用迭代预分类的方式来采样负样本。具体流程如下图所示:

首先我们会使用全部的正样本和从负例候选集中随机采样一部分负样本(这里假如是100条)去训练第一轮分类器;然后用第一轮分类器去预测负例候选集剩余的9900条数据,把9900条负例中预测为正例的样本(也就是预测错误的样本)再随机采样100条和第一轮训练的数据放到一起去训练第二轮分类器;同样的方法用第二轮分类器去预测负例候选集剩余的9800条数据,直到训练的第N轮分类器可以全部识别负例候选集,这就是使用迭代预分类的方式进行欠采样。

相比于随机欠采样来说迭代预分类的欠采样方式能最大限度的利用负样本中差异性较大的负样本,从而在控制正负样本比例的基础上采样出了最有代表意义的负样本。

欠采样的方式整体来说或多或少的会损失一些样本,对于那些需要控制样本量级的场景下比较合适。如果没有严格控制样本量级的要求那么下面的过采样可能会更加适合你。

3.2 过采样

过采样和上面的欠采样比较类似,都是人工干预控制样本的比例,不同的是过采样不会损失样本。还拿上面的例子,现在有正样本100条,负样本1W条,最简单的过采样方式是我们会使用全部的负样本1W条,但是为了维持正负样本比例,我们会从正样本中有放回的重复采样,直到获取了1000条正样本,也就是说有些正样本可能会被重复采样到,这样就能保持1:10的正负样本比例了。这是最简单的过采样方式,这种方式可能会存在严重的过拟合。

实际的场景中会通过样本增强的技术来增加正样本。

提升模型鲁棒性

提升模型鲁棒性的方法有很多,其中对抗训练、知识蒸馏、防止模型过拟合和多模型融合是常见的稳定提升方式,let's see see!

对抗训练

对抗训练是一种能有效提高模型鲁棒性和泛化能力的训练手段,其基本原理是通过在原始输入上增加对抗扰动,得到对抗样本,再利用对抗样本进行训练,从而提高模型的表现。

由于自然语言文本是离散的,一般会把对抗扰动添加到嵌入层上。为了最大化对抗样本的扰动能力,利用梯度上升的方式生成对抗样本。为了避免扰动过大,将梯度做了归一化处理。

$$

{g} = -\bigtriangledown_ {\mathcal{L}}(y_i|{v}; {\hat{\theta}} ) \

{v}^* = {v}+ \epsilon{g} / |{g}|_2

$$

其中,v为嵌入向量。实际训练过程中,我们在训练完一个batch的原始输入数据时,保存当前batch对输入词向量的梯度,得到对抗样本后,再使用对抗样本进行对抗训练。

对抗训练pytorch代码实现

class FGM():
def __init__(self, model):
self.model = model
self.backup = {} def attack(self, epsilon=1., emb_name='emb'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0:
r_at = epsilon * param.grad / norm
param.data.add_(r_at) def restore(self, emb_name='emb'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}

训练中加入几行代码

# 初始化
fgm = FGM(model)
for batch_input, batch_label in data:
# 正常训练
loss = model(batch_input, batch_label)
loss.backward()
# 对抗训练
fgm.attack() # 修改embedding
# optimizer.zero_grad() # 梯度累加,不累加去掉注释
loss_sum = model(batch_input, batch_label)
loss_sum.backward() # 累加对抗训练的梯度
fgm.restore() # 恢复Embedding的参数 optimizer.step()
optimizer.zero_grad()

代码链接:

py版本:https://github.com/PouringRain/blog_code/blob/main/nlp/at.py

喜欢的话,给萌新的github仓库一颗小星星哦……^ _^

知识蒸馏

与对抗训练类似,知识蒸馏也是一种常用的提高模型泛化能力的训练方法。

知识蒸馏这个概念最早由Hinton在2015年提出。一开始,知识蒸馏通往往应用在模型压缩方面,利用训练好的复杂模型(teacher model)输出作为监督信号去训练另一个简单模型(student model),从而将teacher学习到的知识迁移到student。Tommaso在18年提出,如果student和teacher的模型完全相同,蒸馏后则会对模型的表现有一定程度上的提升。

防止模型过拟合

正则化

L1和L2正则化

L1正则化可以得到稀疏解,L2正则化可以得到平滑解,原因参考(https://blog.csdn.net/f156207495/article/details/82794151)。

Dropout

dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。dropout为什么能防止过拟合,可以通过以下几个方面来解释:

  1. 它强迫一个神经单元,和随机挑选出来的其他神经单元共同工作,达到好的效果。消除减弱了神经元节点间的联合适应性,增强了泛化能力。
  2. 类似于bagging的集成效果
  3. 对于每一个dropout后的网络,进行训练时,相当于做了Data Augmentation,因为,总可以找到一个样本,使得在原始的网络上也能达到dropout单元后的效果。 比如,对于某一层,dropout一些单元后,形成的结果是(1.5,0,2.5,0,1,2,0),其中0是被drop的单元,那么总能找到一个样本,使得结果也是如此。这样,每一次dropout其实都相当于增加了样本。

dropout在测试时,并不会随机丢弃神经元,而是使用全部所有的神经元,同时,所有的权重值都乘上1-p,p代表的是随机失活率。

数据增强

数据增强即需要得到更多的符合要求的数据,即和已有的数据是独立同分布的,或者近似独立同分布的。一般有以下方法:

1)从数据源头采集更多数据

2)复制原有数据并加上随机噪声

3)重采样

4)根据当前数据集估计数据分布参数,使用该分布产生更多数据等

Early stopping

在模型对训练数据集迭代收敛之前停止迭代来防止过拟合。因为在初始化网络的时候一般都是初始为较小的权值,训练时间越长,部分网络权值可能越大。如果我们在合适时间停止训练,就可以将网络的能力限制在一定范围内。

交叉验证

交叉验证的基本思想就是将原始数据(dataset)进行分组,一部分做为训练集来训练模型,另一部分做为测试集来评价模型。我们常用的交叉验证方法有简单交叉验证、S折交叉验证和留一交叉验证。

Batch Normalization

一种非常有用的正则化方法,可以让大型的卷积网络训练速度加快很多倍,同时收敛后分类的准确率也可以大幅度的提高。BN在训练某层时,会对每一个mini-batch数据进行标准化(normalization)处理,使输出规范到N(0,1)的正态分布,减少了Internal convariate shift(内部神经元分布的改变),传统的深度神经网络在训练是,每一层的输入的分布都在改变,因此训练困难,只能选择用一个很小的学习速率,但是每一层用了BN后,可以有效的解决这个问题,学习速率可以增大很多倍。

选择合适的网络结构

通过减少网络层数、神经元个数、全连接层数等降低网络容量

多模型融合

Baggging &Boosting,将弱分类器融合之后形成一个强分类器,而且融合之后的效果会比最好的弱分类器更好,三个臭皮匠顶一个诸葛亮。

参考资料

  1. 文本分类中的样本不均衡问题
  2. 功守道:NLP 中的对抗训练 + PyTorch 实现
  3. 欠拟合,过拟合及如何防止过拟合
  4. 知识蒸馏论文
  5. 苏神focal loss理解

Bert文本分类实践(三):处理样本不均衡和提升模型鲁棒性trick的更多相关文章

  1. Bert文本分类实践(二):魔改Bert,融合TextCNN的新思路

    写在前面 ​ 文本分类是nlp中一个非常重要的任务,也是非常适合入坑nlp的第一个完整项目.虽然文本分类看似简单,但里面的门道好多好多,博主水平有限,只能将平时用到的方法和trick在此做个记录和分享 ...

  2. Bert文本分类实践(一):实现一个简单的分类模型

    写在前面 文本分类是nlp中一个非常重要的任务,也是非常适合入坑nlp的第一个完整项目.虽然文本分类看似简单,但里面的门道好多好多,作者水平有限,只能将平时用到的方法和trick在此做个记录和分享,希 ...

  3. Pytorch——BERT 预训练模型及文本分类

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  4. 中文文本分类之TextRNN

    RNN模型由于具有短期记忆功能,因此天然就比较适合处理自然语言等序列问题,尤其是引入门控机制后,能够解决长期依赖问题,捕获输入样本之间的长距离联系.本文的模型是堆叠两层的LSTM和GRU模型,模型的结 ...

  5. 利用RNN进行中文文本分类(数据集是复旦中文语料)

    利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) 1.训练词向量 数据预处理参考利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) ,现在我们有了分词 ...

  6. 用迁移学习创造的通用语言模型ULMFiT,达到了文本分类的最佳水平

    https://www.jqr.com/article/000225 这篇文章的目的是帮助新手和外行人更好地了解我们新论文,我们的论文展示了如何用更少的数据自动将文本分类,同时精确度还比原来的方法高. ...

  7. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  8. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  9. 文本分类实战(十)—— BERT 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

随机推荐

  1. HTTP系列之:HTTP中的cookies

    目录 简介 cookies的作用 创建cookies cookies的生存时间 cookies的权限控制 第三方cookies 总结 简介 如果小伙伴最近有访问国外的一些标准网站的话,可能经常会弹出一 ...

  2. vue 封装 axios 和 各类的请求,以及引入 .vue 文件中使用

    //src 底下建立 api 文件夹 // api 文件夹下建立 request,js 文件,文件内容复制下面这段代码即可   /**  * ajax请求配置  */ import axios fro ...

  3. 学习Linux tar 命令:最简单也最困难

    摘要:在本文中,您将学习与tar 命令一起使用的最常用标志.如何创建和提取 tar 存档以及如何创建和提取 gzip 压缩的 tar 存档. 本文分享自华为云社区<Linux 中的 Tar 命令 ...

  4. WebService学习总结(二)--使用JDK开发WebService

    一.WebService的开发方法 使用java的WebService时可以使用一下两种开发手段 使用jdk开发(1.6及以上版本) 使用CXF框架开发(工作中) 二.使用JDK开发WebServic ...

  5. 了解mysql concat()函数

    concat(arg1,arg2,....):将形参对应字段的值组合成一个字符串 假设:现在有一张学生表(test_user) 将这三个字段组合成一个字符串作为第四个字段 select test_us ...

  6. go语言游戏服务端开发(三)——服务机制

    五邑隐侠,本名关健昌,12年游戏生涯. 本教程以Go语言为例.   P2P网络为服务进程间.服务进程与客户端间通信提供了便利,在这个基础上可以搭建服务. 在服务层,通信包可以通过定义协议号来确定该包怎 ...

  7. HDU 6170 Two strings( DP+字符串匹配)

    http://acm.hdu.edu.cn/showproblem.php?pid=6170 题目大意: 给出两个字符串s1和s2(长度小于等于2500). s1是一个正常的包含大小写字母的字符串,s ...

  8. 学习PHP中统计扩展函数的使用

    做统计相关系统的朋友一定都会学习过什么正态分布.方差.标准差之类的概念,在 PHP 中,也有相应的扩展函数是专门为这些统计相关的功能所开发的.我们今天要学习的 stats 扩展函数库就是这类操作函数. ...

  9. Git 访问慢 解决办法

    1. 查询Git最快的IP 通过 https://www.ipaddress.com/ 这个网站来获取当前github最新的ip分别获取以下两个域名的IP地址: 可以在访问git网站使用F12查询哪个 ...

  10. I/O流中的字节流

    今天总结一下Java中重要的知识点I/O流,今天主要学习了字节流(自己的理解) 什么是I/O:我们把这种数据的传输,可以看做是一种数据的流动,按照流动的方向,以内存为基准,分为输入input和输出ou ...