最近实在是有点忙,没啥时间写博客了。趁着周末水一文,把最近用 huggingface transformers 训练文本分类模型时遇到的一个小问题说下。

背景

之前只闻 transformers 超厉害超好用,但是没有实际用过。之前涉及到 bert 类模型都是直接手写或是在别人的基础上修改。但这次由于某些原因,需要快速训练一个简单的文本分类模型。其实这种场景应该挺多的,例如简单的 POC 或是临时测试某些模型。

我的需求很简单:用我们自己的数据集,快速训练一个文本分类模型,验证想法。

我觉得如此简单的一个需求,应该有模板代码。但实际去搜的时候发现,官方文档什么时候变得这么多这么庞大了?还多了个 Trainer API?瞬间让我想起了 Pytorch Lightning 那个坑人的同名 API。但可能是时间原因,找了一圈没找到适用于自定义数据集的代码,都是用的官方、预定义的数据集。

所以弄完后,我决定简单写一个文章,来说下这原本应该极其容易解决的事情。

数据

假设我们数据的格式如下:

0 第一个句子
1 第二个句子
0 第三个句子

即每一行都是 label sentence 的格式,中间空格分隔。并且我们已将数据集分成了 train.txtval.txt

代码

加载数据集

首先使用 datasets 加载数据集:

from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})

加载后的 dataset 是一个 DatasetDict 对象:

DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 3
})
test: Dataset({
features: ['text'],
num_rows: 3
})
})

类似 tf.data ,此后我们需要对其进行 map ,对每一个句子进行 tokenize、padding、batch、shuffle:

def tokenize_function(examples):
labels = []
texts = []
for example in examples['text']:
split = example.split(' ', maxsplit=1)
labels.append(int(split[0]))
texts.append(split[1])
tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
tokenized['labels'] = labels
return tokenized tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

根据数据集格式不同,我们可以在 tokenize_function 中随意自定义处理过程,以得到 text 和 labels。注意 batch_sizemax_length 也是在此处指定。处理完我们便得到了可以输入给模型的训练集和测试集。

训练

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()

你可以根据情况修改训练 batchsize per_device_train_batch_size

完整代码

完整代码见 GitHub

END

使用 Transformers 在你自己的数据集上训练文本分类模型的更多相关文章

  1. (2) 用DPM(Deformable Part Model,voc-release4.01)算法在INRIA数据集上训练自己的人体检測模型

    步骤一,首先要使voc-release4.01目标检測部分的代码在windows系统下跑起来: 參考在window下执行DPM(deformable part models) -(检測demo部分) ...

  2. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  3. [PocketFlow]解决TensorFLow在COCO数据集上训练挂起无输出的bug

    1. 引言 因项目要求,需要在PocketFlow中添加一套PeleeNet-SSD和COCO的API,具体为在datasets文件夹下添加coco_dataset.py, 在nets下添加pelee ...

  4. CaffeExample 在CIFAR-10数据集上训练与测试

    本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...

  5. 第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)

    前面已经介绍了几种经典的目标检测算法,光学习理论不实践的效果并不大,这里我们使用谷歌的开源框架来实现目标检测.至于为什么不去自己实现呢?主要是因为自己实现比较麻烦,而且调参比较麻烦,我们直接利用别人的 ...

  6. NVIDIA GPUs上深度学习推荐模型的优化

    NVIDIA GPUs上深度学习推荐模型的优化 Optimizing the Deep Learning Recommendation Model on NVIDIA GPUs 推荐系统帮助人在成倍增 ...

  7. Microsoft Dynamics CRM 2011 当您在 大型数据集上执行 RetrieveMultiple 查询很慢的解决方法

    症状 当您在 Microsoft Dynamics CRM 2011 年大型数据集上执行 RetrieveMultiple 查询时,您会比较慢. 原因 发生此问题是因为大型数据集缓存 Retrieve ...

  8. 在Titanic数据集上应用AdaBoost元算法

    一.AdaBoost 元算法的基本原理 AdaBoost是adaptive boosting的缩写,就是自适应boosting.元算法是对于其他算法进行组合的一种方式. 而boosting是在从原始数 ...

  9. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

  10. BP算法在minist数据集上的简单实现

    BP算法在minist上的简单实现 数据:http://yann.lecun.com/exdb/mnist/ 参考:blog,blog2,blog3,tensorflow 推导:http://www. ...

随机推荐

  1. GSLB工作原理

    参考文档: http://chongit.github.io/2015/04/15/GSLB%E6%A6%82%E8%A6%81%E5%92%8C%E5%AE%9E%E7%8E%B0%E5%8E%9F ...

  2. 阿里云经典网络Debian 11 启动非常慢

    有一台阿里云经典网络的实例.系统太老了,重装了Debian 11,但是启动非常慢,要5分钟才能开机,简直离谱. root@AliYun:~# systemd-analyze blame 5min 3. ...

  3. git入门123

    一.新手上路 最重要的4招: 1. 初始化本地仓库 git init 或者 git clone 远程仓库地址 2.添加改动文件 git add 改动的文件名或者目录 偷懒的话可以直接 git add ...

  4. stream-分组两次

    Map<String, Map<String, List<AmazonBalanceCustom>>> amazonBalanceMap = amazonBalan ...

  5. SecurityRandom随机数生成

    package com.netauth.utils; import java.security.SecureRandom; /** * * <p> * SecureRandom随机数生成工 ...

  6. Qframework UIKit

    用QFramework的UIKit 功能很容易实现UI模块的MVC功能,但MVC模式构造起来还是会有些繁琐, 两个相互直接的UIElement 之间的一些数据传输和调用都要用Msg通过UIPanel ...

  7. centos删除安装vsftpd

    准备工作 1.centos 卸载vsftpd 删除原有的vsftpd(卸载前先关闭 vsftpd: systemctl stop vsftpd)[root@localhost ~]# rpm -aq ...

  8. [转]Selenium私房菜系列1 -- Selenium简介

    一.Selenium是什么? Selenium是ThroughtWorks公司一个强大的开源Web功能测试工具系列,本系列现在主要包括以下4款: 1.Selenium Core:支持DHTML的测试案 ...

  9. 20191323王予涵sort

    sort 任务 用man sort 查看sort的帮助文档 sort常用选项有哪些,都有什么功能?提交相关使用的截图 如果让你编写sort,你怎么实现?写出伪代码和相关的函数或系统调用 一.查看帮助文 ...

  10. JRAT远控

    JRAT java写的一款远控 可以控制mac os  Windows linux 系统 https://share.weiyun.com/21b56e3e5cab4b3f145d7c2330d107 ...