相关文章:

基础知识介绍:

【一】ERNIE:飞桨开源开发套件,入门学习,看看行业顶尖持续学习语义理解框架,如何取得世界多个实战的SOTA效果?_汀、的博客-CSDN博客_ernie模型

百度飞桨:ERNIE 3.0 、通用信息抽取 UIE、paddleNLP的安装使用[一]_汀、的博客-CSDN博客_paddlenlp 安装


项目实战:

PaddleHub--飞桨预训练模型应用工具{风格迁移模型、词法分析情感分析、Fine-tune API微调}【一】_汀、的博客-CSDN博客

PaddleHub--{超参优化AutoDL Finetuner}【二】_汀、的博客-CSDN博客

PaddleHub实战篇{词法分析模型LAC、情感分类ERNIE Tiny}训练、部署【三】_汀、的博客-CSDN博客

PaddleHub实战篇{ERNIE实现文新闻本分类、ERNIE3.0 实现序列标注}【四】_汀、的博客-CSDN博客


通过前面几篇文章大家都有一定了解,下面直接上代码讲解

1.ERNIE实现文新闻本分类

用最新版本paddlenlp和paddle!

项目链接:ERNIE实现新闻文本分类 -

  1. !pip install --upgrade paddlenlp
  2. !pip install -U paddlehub
  3. #安装
  1. import paddlehub as hub
  2. import paddle
  3. model = hub.Module(name="ernie", task='seq-cls', num_classes=14) # 在多分类任务中,num_classes需要显式地指定类别数,此处根据数据集设置为14

hub.Module的参数用法如下:

  • name:模型名称,可以选择ernieernie_tinybert-base-cased, bert-base-chineseroberta-wwm-extroberta-wwm-ext-large等。
  • task:fine-tune任务。此处为seq-cls,表示文本分类任务。
  • num_classes:表示当前文本分类任务的类别数,根据具体使用的数据集确定,默认为2。

NOTE: 文本多分类的任务中,num_classes需要用户指定,具体的类别数根据选用的数据集确定,本教程中为14。

PaddleHub还提供BERT等模型可供选择, 当前支持文本分类任务的模型对应的加载示例如下:

模型名 PaddleHub Module
ERNIE, Chinese hub.Module(name='ernie')
ERNIE tiny, Chinese hub.Module(name='ernie_tiny')
ERNIE 2.0 Base, English hub.Module(name='ernie_v2_eng_base')
ERNIE 2.0 Large, English hub.Module(name='ernie_v2_eng_large')
BERT-Base, English Cased hub.Module(name='bert-base-cased')
BERT-Base, English Uncased hub.Module(name='bert-base-uncased')
BERT-Large, English Cased hub.Module(name='bert-large-cased')
BERT-Large, English Uncased hub.Module(name='bert-large-uncased')
BERT-Base, Multilingual Cased hub.Module(nane='bert-base-multilingual-cased')
BERT-Base, Multilingual Uncased hub.Module(nane='bert-base-multilingual-uncased')
BERT-Base, Chinese hub.Module(name='bert-base-chinese')
BERT-wwm, Chinese hub.Module(name='chinese-bert-wwm')
BERT-wwm-ext, Chinese hub.Module(name='chinese-bert-wwm-ext')
RoBERTa-wwm-ext, Chinese hub.Module(name='roberta-wwm-ext')
RoBERTa-wwm-ext-large, Chinese hub.Module(name='roberta-wwm-ext-large')
RBT3, Chinese hub.Module(name='rbt3')
RBTL3, Chinese hub.Module(name='rbtl3')
ELECTRA-Small, English hub.Module(name='electra-small')
ELECTRA-Base, English hub.Module(name='electra-base')
ELECTRA-Large, English hub.Module(name='electra-large')
ELECTRA-Base, Chinese hub.Module(name='chinese-electra-base')
ELECTRA-Small, Chinese hub.Module(name='chinese-electra-small')

通过以上的一行代码,model初始化为一个适用于文本分类任务的模型,为ERNIE的预训练模型后拼接上一个全连接网络(Full Connected

1.1: 加载自定义数据集

本示例数据集是由清华大学提供的新闻文本数据集THUCNews。THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档(2.19 GB),均为UTF-8纯文本格式。我们在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐。为了快速展示如何使用PaddleHub完成文本分类任务,该示例数据集从THUCNews训练集中随机抽取了9000条文本数据集作为本示例的训练集,从验证集中14个类别每个类别随机抽取100条数据作为本示例的验证集,测试集抽取方式和验证集相同。

数据集:https://aistudio.baidu.com/aistudio/projectdetail/4153654

首先解压数据集。

  1. # 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
  2. # View dataset directory. This directory will be recovered automatically after resetting environment.
  3. %cd /home/aistudio/data/data16287/
  4. !tar -zxvf thu_news.tar.gz
  5. !ls -hl thu_news
  6. !head -n 3 thu_news/train.txt

加载自定义数据集,用户仅需要继承TextClassificationDataset类。 下面代码示例展示如何将自定义数据集加载进PaddleHub使用。

具体详情可参考 加载自定义数据集

  1. import os, io, csv
  2. from paddlehub.datasets.base_nlp_dataset import InputExample, TextClassificationDataset
  3. # 数据集存放位置
  4. DATA_DIR="/home/aistudio/data/data16287/thu_news"
  1. class ThuNews(TextClassificationDataset):
  2. def __init__(self, tokenizer, mode='train', max_seq_len=128):
  3. if mode == 'train':
  4. data_file = 'train.txt'
  5. elif mode == 'test':
  6. data_file = 'test.txt'
  7. else:
  8. data_file = 'valid.txt'
  9. super(ThuNews, self).__init__(
  10. base_path=DATA_DIR,
  11. data_file=data_file,
  12. tokenizer=tokenizer,
  13. max_seq_len=max_seq_len,
  14. mode=mode,
  15. is_file_with_header=True,
  16. label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚'])
  17. # 解析文本文件里的样本
  18. def _read_file(self, input_file, is_file_with_header: bool = False):
  19. if not os.path.exists(input_file):
  20. raise RuntimeError("The file {} is not found.".format(input_file))
  21. else:
  22. with io.open(input_file, "r", encoding="UTF-8") as f:
  23. reader = csv.reader(f, delimiter="\t", quotechar=None)
  24. examples = []
  25. seq_id = 0
  26. header = next(reader) if is_file_with_header else None
  27. for line in reader:
  28. example = InputExample(guid=seq_id, text_a=line[0], label=line[1])
  29. seq_id += 1
  30. examples.append(example)
  31. return examples
  32. train_dataset = ThuNews(model.get_tokenizer(), mode='train', max_seq_len=128)
  33. dev_dataset = ThuNews(model.get_tokenizer(), mode='dev', max_seq_len=128)
  34. test_dataset = ThuNews(model.get_tokenizer(), mode='test', max_seq_len=128)
  35. for e in train_dataset.examples[:3]:
  36. print(e)

NOTE: 最大序列长度max_seq_len是可以调整的参数,建议值128,根据任务文本长度不同可以调整该值,但最大不超过512。

1.2: 选择优化策略和运行配置

  1. optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters()) # 优化器的选择和参数配置
  2. trainer = hub.Trainer(model, optimizer, checkpoint_dir='./ckpt', use_gpu=True) # fine-tune任务的执行者

优化策略

Paddle2.0-rc提供了多种优化器选择,如SGDAdamAdamax等,详细参见策略

在本教程中选择了Adam优化器,其的参数用法:

  • learning_rate: 全局学习率。默认为1e-3;
  • parameters: 待优化模型参数。

运行配置

Trainer 主要控制Fine-tune任务的训练,是任务的发起者,包含以下可控制的参数:

  • model: 被优化模型;
  • optimizer: 优化器选择;
  • use_gpu: 是否使用gpu训练;
  • use_vdl: 是否使用vdl可视化训练过程;
  • checkpoint_dir: 保存模型参数的地址;
  • compare_metrics: 保存最优模型的衡量指标;

1.3: 执行fine-tune并评估模型

  1. trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset, save_interval=1) # 配置训练参数,启动训练,并指定验证集
  1. [2022-01-19 15:02:20,462] [ TRAIN] - Epoch=3/3, Step=170/282 loss=0.0352 acc=0.9875 lr=0.000050 step/sec=4.33 | ETA 00:04:31
  2. [2022-01-19 15:02:22,777] [ TRAIN] - Epoch=3/3, Step=180/282 loss=0.1361 acc=0.9656 lr=0.000050 step/sec=4.32 | ETA 00:04:30
  3. [2022-01-19 15:02:25,096] [ TRAIN] - Epoch=3/3, Step=190/282 loss=0.0730 acc=0.9844 lr=0.000050 step/sec=4.31 | ETA 00:04:29
  4. [2022-01-19 15:02:27,415] [ TRAIN] - Epoch=3/3, Step=200/282 loss=0.0645 acc=0.9875 lr=0.000050 step/sec=4.31 | ETA 00:04:28
  5. [2022-01-19 15:02:29,729] [ TRAIN] - Epoch=3/3, Step=210/282 loss=0.0652 acc=0.9844 lr=0.000050 step/sec=4.32 | ETA 00:04:27
  6. [2022-01-19 15:02:32,048] [ TRAIN] - Epoch=3/3, Step=220/282 loss=0.1083 acc=0.9594 lr=0.000050 step/sec=4.31 | ETA 00:04:26
  7. [2022-01-19 15:02:34,362] [ TRAIN] - Epoch=3/3, Step=230/282 loss=0.1116 acc=0.9656 lr=0.000050 step/sec=4.32 | ETA 00:04:25
  8. [2022-01-19 15:02:36,673] [ TRAIN] - Epoch=3/3, Step=240/282 loss=0.1040 acc=0.9656 lr=0.000050 step/sec=4.33 | ETA 00:04:24
  9. [2022-01-19 15:02:38,976] [ TRAIN] - Epoch=3/3, Step=250/282 loss=0.0556 acc=0.9844 lr=0.000050 step/sec=4.34 | ETA 00:04:23
  10. [2022-01-19 15:02:41,289] [ TRAIN] - Epoch=3/3, Step=260/282 loss=0.0755 acc=0.9750 lr=0.000050 step/sec=4.32 | ETA 00:04:22
  11. [2022-01-19 15:02:43,607] [ TRAIN] - Epoch=3/3, Step=270/282 loss=0.1749 acc=0.9563 lr=0.000050 step/sec=4.31 | ETA 00:04:22
  12. [2022-01-19 15:02:45,918] [ TRAIN] - Epoch=3/3, Step=280/282 loss=0.1602 acc=0.9594 lr=0.000050 step/sec=4.33 | ETA 00:04:21
  1. result = trainer.evaluate(test_dataset, batch_size=32) # 在测试集上评估当前训练模型

1.4、使用模型进行预测

当Finetune完成后,我们加载训练后保存的最佳模型来进行预测,完整预测代码如下:

  1. # Data to be prdicted
  2. data = [
  3. # 房产
  4. ["昌平京基鹭府10月29日推别墅1200万套起享97折  新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。  京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。"],
  5. # 游戏
  6. ["尽管官方到今天也没有公布《使命召唤:现代战争2》的游戏详情,但《使命召唤:现代战争2》首部包含游戏画面的影片终于现身。虽然影片仅有短短不到20秒,但影片最后承诺大家将于美国时间5月24日NBA职业篮球东区决赛时将会揭露更多的游戏内容。  这部只有18秒的广告片闪现了9个镜头,能够辨识的场景有直升机飞向海岛军事工事,有飞机场争夺战,有潜艇和水下工兵,有冰上乘具,以及其他的一些镜头。整体来看《现代战争2》很大可能仍旧与俄罗斯有关。  片尾有一则预告:“May24th,EasternConferenceFinals”,这是什么?这是说当前美国NBA联赛东部总决赛的日期。原来这部视频是NBA季后赛奥兰多魔术对波士顿凯尔特人队时,TNT电视台播放的广告。"],
  7. # 体育
  8. ["罗马锋王竟公然挑战两大旗帜拉涅利的球队到底错在哪  记者张恺报道主场一球小胜副班长巴里无可吹捧,罗马占优也纯属正常,倒是托蒂罚失点球和前两号门将先后受伤(多尼以三号身份出场)更让人揪心。阵容规模扩大,反而表现不如上赛季,缺乏一流强队的色彩,这是所有球迷对罗马的印象。  拉涅利说:“去年我们带着嫉妒之心看国米,今年我们也有了和国米同等的超级阵容,许多教练都想有罗马的球员。阵容广了,寻找队内平衡就难了,某些时段球员的互相排斥和跟从前相比的落差都正常。有好的一面,也有不好的一面,所幸,我们一直在说一支伟大的罗马,必胜的信念和够级别的阵容,我们有了。”拉涅利的总结由近一阶段困扰罗马的队内摩擦、个别球员闹意见要走人而发,本赛季技术层面强化的罗马一直没有上赛季反扑的面貌,内部变化值得球迷关注。"],
  9. # 教育
  10. ["新总督致力提高加拿大公立教育质量  滑铁卢大学校长约翰斯顿先生于10月1日担任加拿大总督职务。约翰斯顿先生还曾任麦吉尔大学长,并曾在多伦多大学、女王大学和西安大略大学担任教学职位。  约翰斯顿先生在就职演说中表示,要将加拿大建设成为一个“聪明与关爱的国度”。为实现这一目标,他提出三个支柱:支持并关爱家庭、儿童;鼓励学习与创造;提倡慈善和志愿者精神。他尤其强调要关爱并尊重教师,并通过公立教育使每个人的才智得到充分发展。"]
  11. ]
  12. label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚']
  13. label_map = {
  14. idx: label_text for idx, label_text in enumerate(label_list)
  15. }
  16. model = hub.Module(
  17. name='ernie',
  18. task='seq-cls',
  19. load_checkpoint='./ckpt/best_model/model.pdparams',
  20. label_map=label_map)
  21. results = model.predict(data, max_seq_len=128, batch_size=1, use_gpu=True)
  22. for idx, text in enumerate(data):
  23. print('Data: {} \t Lable: {}'.format(text[0], results[idx]))
  1. Data: 昌平京基鹭府1029日推别墅1200万套起享97折  新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计1029日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。  京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。 Lable: 房产
  2. Data: 尽管官方到今天也没有公布《使命召唤:现代战争2》的游戏详情,但《使命召唤:现代战争2》首部包含游戏画面的影片终于现身。虽然影片仅有短短不到20秒,但影片最后承诺大家将于美国时间524NBA职业篮球东区决赛时将会揭露更多的游戏内容。  这部只有18秒的广告片闪现了9个镜头,能够辨识的场景有直升机飞向海岛军事工事,有飞机场争夺战,有潜艇和水下工兵,有冰上乘具,以及其他的一些镜头。整体来看《现代战争2》很大可能仍旧与俄罗斯有关。  片尾有一则预告:“May24thEasternConferenceFinals”,这是什么?这是说当前美国NBA联赛东部总决赛的日期。原来这部视频是NBA季后赛奥兰多魔术对波士顿凯尔特人队时,TNT电视台播放的广告。 Lable: 游戏
  3. Data: 罗马锋王竟公然挑战两大旗帜拉涅利的球队到底错在哪  记者张恺报道主场一球小胜副班长巴里无可吹捧,罗马占优也纯属正常,倒是托蒂罚失点球和前两号门将先后受伤(多尼以三号身份出场)更让人揪心。阵容规模扩大,反而表现不如上赛季,缺乏一流强队的色彩,这是所有球迷对罗马的印象。  拉涅利说:“去年我们带着嫉妒之心看国米,今年我们也有了和国米同等的超级阵容,许多教练都想有罗马的球员。阵容广了,寻找队内平衡就难了,某些时段球员的互相排斥和跟从前相比的落差都正常。有好的一面,也有不好的一面,所幸,我们一直在说一支伟大的罗马,必胜的信念和够级别的阵容,我们有了。”拉涅利的总结由近一阶段困扰罗马的队内摩擦、个别球员闹意见要走人而发,本赛季技术层面强化的罗马一直没有上赛季反扑的面貌,内部变化值得球迷关注。 Lable: 体育
  4. Data: 新总督致力提高加拿大公立教育质量  滑铁卢大学校长约翰斯顿先生于101日担任加拿大总督职务。约翰斯顿先生还曾任麦吉尔大学长,并曾在多伦多大学、女王大学和西安大略大学担任教学职位。  约翰斯顿先生在就职演说中表示,要将加拿大建设成为一个“聪明与关爱的国度”。为实现这一目标,他提出三个支柱:支持并关爱家庭、儿童;鼓励学习与创造;提倡慈善和志愿者精神。他尤其强调要关爱并尊重教师,并通过公立教育使每个人的才智得到充分发展。 Lable: 教育

ERNIE实现新闻文本分类 - 项目链接

2.ERNIE 3.0--MSRA序列标注实战

  1. !pip install --upgrade paddlenlp
  2. !pip install -U paddlehub

PaddleHub2.0——使用动态图版预训练模型ERNIE实现序列标注

2.1. 什么是序列标注

序列标注(Sequence Tagging)是经典的自然语言处理问题,可以用于解决一系列字符分类问题,例如分词、词性标注(POS tagging)、命名实体识别(Named Entity Recognition,NER)、关键词抽取、语义角色标注(Semantic Role Labeling)、槽位抽取(Slot Filling)。在现实场景中,序列标注技术可以帮助完成简历、快递单、病例医疗实体信息抽取等。

在序列标注任务中,一般会定义一个标签集合来表示所有预测结果。对于输入的序列,任务目标是对序列中所有字符进行标记。在深度学习中,通常将序列标注问题视为分类问题,对输入序列的每一个token进行一次多分类任务进行训练预测。

ERNIE 3.0框架分为两层。第一层是通用语义表示网络,该网络学习数据中的基础和通用的知识。第二层是任务语义表示网络,该网络基于通用语义表示,学习任务相关的知识。在学习过程中,任务语义表示网络只学习对应类别的预训练任务,而通用语义表示网络会学习所有的预训练任务。

2.3. ERNIE 3.0中文预训练模型进行MSRA序列标注

2.3.1 环境准备

AI Studio平台默认安装了Paddle和PaddleNLP,并定期更新版本。 如需手动更新Paddle,可参考飞桨安装说明,安装相应环境下最新版飞桨框架。使用如下命令确保安装最新版PaddleNLP:

  1. import os
  2. import paddle
  3. import paddlenlp

2.3.2 加载MSRA-NER数据集

MSRA-NER 数据集由微软亚研院发布,其目标是识别文本中具有特定意义的实体,主要包括人名、地名、机构名等。PaddleNLP已经内置该数据集,一键即可加载。PaddleNLP集成的数据集MSRA-NER数据集对文件格式做了调整:每一行文本、标签以特殊字符"\t"进行分隔,每个字之间以特殊字符"\002"分隔。示例如下:

  1. \002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002O\002O\002O\002O\002B-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002O\002O\002O\002O\002O\002O\002O\002O\002B-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002O
  2. \002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002\002O\002O\002O\002O\002O\002O\002O\002O\002B-LOC\002I-LOC\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O

加载MSRA_NER数据集为BIO标注集:

  • B-PER、I-PER代表人名首字、人名非首字。
  • B-LOC、I-LOC代表地名首字、地名非首字。
  • B-ORG、I-ORG代表组织机构名首字、组织机构名非首字。
  • O代表该字不属于命名实体的一部分。
  1. # 加载MSRA_NER数据集
  2. from paddlenlp.datasets import load_dataset
  3. train_ds, test_ds = load_dataset('msra_ner', splits=('train', 'test'), lazy=False)
  4. label_vocab = {label:label_id for label_id, label in enumerate(train_ds.label_list)}
  5. # 数据集返回类型为MapDataset
  6. print("数据类型:", type(train_ds))
  7. print("数据标签:", label_vocab)
  8. # 每条数据包含一句文本和这个文本中每个汉字以及数字对应的label标签
  9. print("训练集样例:", train_ds[0])
  10. print("测试集样例:", test_ds[0])
  1. 数据类型: <class 'paddlenlp.datasets.dataset.MapDataset'>
  2. 数据标签: {'B-PER': 0, 'I-PER': 1, 'B-ORG': 2, 'I-ORG': 3, 'B-LOC': 4, 'I-LOC': 5, 'O': 6}
  3. 训练集样例: {'tokens': ['当', '希', '望', '工', '程', '救', '助', '的', '百', '万', '儿', '童', '成', '长', '起', '来', ',', '科', '教', '兴', '国', '蔚', '然', '成', '风', '时', ',', '今', '天', '有', '收', '藏', '价', '值', '的', '书', '你', '没', '买', ',', '明', '日', '就', '叫', '你', '悔', '不', '当', '初', '!'], 'labels': [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]}
  4. 测试集样例: {'tokens': ['中', '共', '中', '央', '致', '中', '国', '致', '公', '党', '十', '一', '大', '的', '贺', '词', '各', '位', '代', '表', '、', '各', '位', '同', '志', ':', '在', '中', '国', '致', '公', '党', '第', '十', '一', '次', '全', '国', '代', '表', '大', '会', '隆', '重', '召', '开', '之', '际', ',', '中', '国', '共', '产', '党', '中', '央', '委', '员', '会', '谨', '向', '大', '会', '表', '示', '热', '烈', '的', '祝', '贺', ',', '向', '致', '公', '党', '的', '同', '志', '们', '致', '以', '亲', '切', '的', '问', '候', '!'], 'labels': [2, 3, 3, 3, 6, 2, 3, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]}

2.3.3 加载中文ERNIE 3.0预训练模型和分词器

PaddleNLP中Auto模块(包括AutoModel, AutoTokenizer及各种下游任务类)提供了方便易用的接口,无需指定模型类别,即可调用不同网络结构的预训练模型。PaddleNLP的预训练模型可以很容易地通过from_pretrained()方法加载,Transformer预训练模型汇总包含了40多个主流预训练模型,500多个模型权重。

AutoForTokenClassification可用于序列标注,通过预训练模型获取输入文本每个token的表示,之后将token表示进行分类。PaddleNLP已经实现了ERNIE 3.0预训练模型,可以通过一行代码实现ERNIE 3.0预训练模型和分词器的加载。

本项目开源 ERNIE 3.0 Base 、ERNIE 3.0 Medium 、 ERNIE 3.0 Mini 、 ERNIE 3.0 Micro 、 ERNIE 3.0 Nano 五个模型:

  1. from paddlenlp.transformers import AutoModelForTokenClassification
  2. from paddlenlp.transformers import AutoTokenizer
  3. model_name = "ernie-3.0-base-zh"
  4. model = AutoModelForTokenClassification.from_pretrained(model_name, num_classes=len(train_ds.label_list))
  5. tokenizer = AutoTokenizer.from_pretrained(model_name)

2.3.4 基于预训练模型的数据处理

Dataset中通常为原始数据,需要经过一定的数据处理转成可输入模型的数据并进行采样组batch。

  • 通过Datasetmap函数,使用分词器将数据集从原始文本处理成模型的输入。
  • 定义paddle.io.BatchSamplercollate_fn构建 paddle.io.DataLoader

实际训练中,根据显存大小调整批大小batch_size和文本最大长度max_seq_length

  1. import functools
  2. import numpy as np
  3. from paddle.io import DataLoader, BatchSampler
  4. from paddlenlp.data import DataCollatorForTokenClassification
  5. # 数据预处理函数,利用分词器将文本转化为整数序列
  6. def preprocess_function(example, tokenizer, label_vocab, max_seq_length=128):
  7. labels = example['labels']
  8. tokens = example['tokens']
  9. no_entity_id = label_vocab['O']
  10. tokenized_input = tokenizer(tokens, return_length=True, is_split_into_words=True, max_seq_len=max_seq_length)
  11. # 保证label与input_ids长度一致
  12. # -2 for [CLS] and [SEP]
  13. if len(tokenized_input['input_ids']) - 2 < len(labels):
  14. labels = labels[:len(tokenized_input['input_ids']) - 2]
  15. tokenized_input['labels'] = [no_entity_id] + labels + [no_entity_id]
  16. tokenized_input['labels'] += [no_entity_id] * (len(tokenized_input['input_ids']) - len(tokenized_input['labels']))
  17. return tokenized_input
  18. trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, label_vocab=label_vocab, max_seq_length=128)
  19. train_ds = train_ds.map(trans_func)
  20. test_ds = test_ds.map(trans_func)
  21. # collate_fn函数构造,将不同长度序列充到批中数据的最大长度,再将数据堆叠
  22. collate_fn = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-1)
  23. # 定义BatchSampler,选择批大小和是否随机乱序,进行DataLoader
  24. train_batch_sampler = BatchSampler(train_ds, batch_size=32, shuffle=True)
  25. test_batch_sampler = BatchSampler(test_ds, batch_size=32, shuffle=False)
  26. train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=collate_fn)
  27. test_data_loader = DataLoader(dataset=test_ds, batch_sampler=test_batch_sampler, collate_fn=collate_fn)

2.3.5 数据训练和评估

定义训练所需的优化器、损失函数、评价指标等,就可以开始进行预模型微调任务。

  1. from paddlenlp.metrics import ChunkEvaluator
  2. # Adam优化器、交叉熵损失函数、ChunkEvaluator评价指标
  3. optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())
  4. criterion = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1)
  5. metric = ChunkEvaluator(label_list=train_ds.label_list)

10个epoch预计训练时间60分钟。

  1. # 开始训练
  2. import time
  3. import paddle.nn.functional as F
  4. from utils import evaluate
  5. epochs = 10 # 训练轮次
  6. ckpt_dir = "ernie_ckpt" #训练过程中保存模型参数的文件夹
  7. best_f1_score = 0
  8. best_step = 0
  9. global_step = 0 #迭代次数
  10. tic_train = time.time()
  11. for epoch in range(1, epochs + 1):
  12. for step, batch in enumerate(train_data_loader, start=1):
  13. input_ids, token_type_ids, labels = batch['input_ids'], batch['token_type_ids'], batch['labels']
  14. # 计算模型输出、损失函数值
  15. logits = model(input_ids, token_type_ids)
  16. loss = paddle.mean(criterion(logits, labels))
  17. # 每迭代10次,打印损失函数值、计算速度
  18. global_step += 1
  19. if global_step % 10 == 0:
  20. print(
  21. "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
  22. % (global_step, epoch, step, loss, 10 / (time.time() - tic_train)))
  23. tic_train = time.time()
  24. # 反向梯度回传
  25. loss.backward()
  26. optimizer.step()
  27. optimizer.clear_grad()
  28. # 每迭代200次,评估当前训练的模型、保存当前最佳模型参数和分词器的词表等
  29. if global_step % 200 == 0:
  30. save_dir = ckpt_dir
  31. if not os.path.exists(save_dir):
  32. os.makedirs(save_dir)
  33. print('global_step', global_step, end=' ')
  34. f1_score_eval = evaluate(model, metric, test_data_loader)
  35. if f1_score_eval > best_f1_score:
  36. best_f1_score = f1_score_eval
  37. best_step = global_step
  38. model.save_pretrained(save_dir)
  39. tokenizer.save_pretrained(save_dir)
  1. global step 13960, epoch: 10, batch: 1297, loss: 0.00087, speed: 5.29 step/s
  2. global step 13970, epoch: 10, batch: 1307, loss: 0.01934, speed: 5.31 step/s
  3. global step 13980, epoch: 10, batch: 1317, loss: 0.00257, speed: 5.27 step/s
  4. global step 13990, epoch: 10, batch: 1327, loss: 0.00045, speed: 4.94 step/s
  5. global step 14000, epoch: 10, batch: 1337, loss: 0.00306, speed: 5.04 step/s
  6. global_step 14000 eval precision: 0.946141 - recall: 0.956587 - f1: 0.951335
  7. global step 14010, epoch: 10, batch: 1347, loss: 0.00233, speed: 0.93 step/s
  8. global step 14020, epoch: 10, batch: 1357, loss: 0.00579, speed: 5.05 step/s
  9. global step 14030, epoch: 10, batch: 1367, loss: 0.00169, speed: 5.24 step/s
  10. global step 14040, epoch: 10, batch: 1377, loss: 0.00024, speed: 4.86 step/s
  11. global step 14050, epoch: 10, batch: 1387, loss: 0.00112, speed: 5.44 step/s
  12. global step 14060, epoch: 10, batch: 1397, loss: 0.00163, speed: 5.10 step/s
  13. global step 14070, epoch: 10, batch: 1407, loss: 0.00009, speed: 5.55 step/s
  14. 运行时长:3460.397秒结束时间:2022-06-01 16:52:00

 utils:文件需要下载:

链接:ERNIE3.0中文预训练模型进行MSRA序列标注-自然语言处理文档类资源-

2.3.6 序列标注结果预测与保存

加载微调好的模型参数进行情感分析预测,并保存预测结果

  1. # 测试集结果评估
  2. from utils import parse_decodes
  3. # 加载最佳模型参数
  4. model.set_dict(paddle.load('ernie_ckpt/model_state.pdparams'))
  5. # 可以加载预先训练好的模型参数结果查看模型训练结果
  6. # model.set_dict(paddle.load('ernie_ckpt_trained/model_state.pdparams'))
  7. model.eval()
  8. metric.reset()
  9. pred_list = []
  10. len_list = []
  11. for step, batch in enumerate(test_data_loader, start=1):
  12. input_ids, token_type_ids, labels, lens = batch['input_ids'], batch['token_type_ids'], batch['labels'], batch['seq_len']
  13. logits = model(input_ids, token_type_ids)
  14. preds = paddle.argmax(logits, axis=-1)
  15. n_infer, n_label, n_correct = metric.compute(lens, preds, labels)
  16. metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())
  17. pred_list.append(preds.numpy())
  18. len_list.append(lens.numpy())
  19. precision, recall, f1_score = metric.accumulate()
  20. print("ERNIE 3.0 在msra_ner的test集表现 -precision: %f - recall: %f - f1: %f" %(precision, recall, f1_score))

结果:

  1. ERNIE 3.0 msra_nertest集表现 -precision: 0.948490 - recall: 0.957897 - f1: 0.953170

我们根据模型预测结果对文本进行后处理,对文本序列进行标注,具体的标签含义如下:

  • ‘O’: no special entity(其他不属于任何实体的字符,包括标点等)
  • ‘PER’: person(人名)
  • ‘ORG’: organization(组织机构)
  • ‘LOC’: location(地点)
  1. # 根据预测结果对文本进行后处理
  2. test_ds = load_dataset('msra_ner', splits=('test'))
  3. preds = parse_decodes(test_ds, pred_list, len_list, label_vocab)
  4. # 查看预测结果
  5. print("查看结果")
  6. print("\n".join(preds[:2]))
  7. # 保存预测结果
  8. with open("results.txt", "w", encoding="utf-8") as f:
  9. f.write("\n".join(preds))

结果:

  1. 查看结果
  2. ('中共中央', 'ORG')('致', 'O')('中国致公党十一大', 'ORG')('的', 'O')('贺', 'O')('词', 'O')('各', 'O')('位', 'O')('代', 'O')('表', 'O')('、', 'O')('各', 'O')('位', 'O')('同', 'O')('志', 'O')(':', 'O')('在', 'O')('中国致公党第十一次全国代表大会', 'ORG')('隆', 'O')('重', 'O')('召', 'O')('开', 'O')('之', 'O')('际', 'O')(',', 'O')('中国***中央委员会', 'ORG')('谨', 'O')('向', 'O')('大', 'O')('会', 'O')('表', 'O')('示', 'O')('热', 'O')('烈', 'O')('的', 'O')('祝', 'O')('贺', 'O')(',', 'O')('向', 'O')('致公党', 'ORG')('的', 'O')('同', 'O')('志', 'O')('们', 'O')('致', 'O')('以', 'O')('亲', 'O')('切', 'O')('的', 'O')('问', 'O')('候', 'O')('!', 'O')
  3. ('在', 'O')('过', 'O')('去', 'O')('的', 'O')('五', 'O')('年', 'O')('中', 'O')(',', 'O')('致公党', 'ORG')('在', 'O')('*', 'PER')('理', 'O')('论', 'O')('指', 'O')('引', 'O')('下', 'O')(',', 'O')('遵', 'O')('循', 'O')('社', 'O')('会', 'O')('主', 'O')('义', 'O')('初', 'O')('级', 'O')('阶', 'O')('段', 'O')('的', 'O')('基', 'O')('本', 'O')('路', 'O')('线', 'O')(',', 'O')('努', 'O')('力', 'O')('实', 'O')('践', 'O')('致公党十大', 'ORG')('提', 'O')('出', 'O')('的', 'O')('发', 'O')('挥', 'O')('参', 'O')('政', 'O')('党', 'O')('职', 'O')('能', 'O')('、', 'O')('加', 'O')('强', 'O')('自', 'O')('身', 'O')('建', 'O')('设', 'O')('的', 'O')('基', 'O')('本', 'O')('任', 'O')('务', 'O')('。', 'O')

欢迎一键三联!

PaddleHub实战篇{ERNIE实现文新闻本分类、ERNIE3.0 实现序列标注}【四】的更多相关文章

  1. 构建NetCore应用框架之实战篇(五):BitAdminCore框架1.0登录功能设计实现及源码

    本篇承接上篇内容,如果你不小心点击进来,建议从第一篇开始完整阅读,文章内容继承性连贯性. 构建NetCore应用框架之实战篇系列 一.设计原则 1.继承前面框架架构思维,设计以可读性作为首要目标. 2 ...

  2. 构建NetCore应用框架之实战篇系列

    构建NetCore应用框架之实战篇 构建NetCore应用框架之实战篇(一):什么是框架,如何设计一个框架 构建NetCore应用框架之实战篇(二):BitAdminCore框架定位及架构 构建Net ...

  3. 二、Redis基本操作——String(实战篇)

    小喵万万没想到,上一篇博客,居然已经被阅读600次了!!!让小喵感觉压力颇大.万一有写错的地方,岂不是会误导很多筒子们.所以,恳请大家,如果看到小喵的博客有什么不对的地方,请尽快指正!谢谢! 小喵的唠 ...

  4. Systemd 入门教程:实战篇

    Systemd 入门教程:实战篇 上一篇文章,介绍了 Systemd 的主要命令,这篇文章主要介绍如何使用 Systemd 来管理我们的服务,以及各项的含义: 一.开机启动 对于那些支持 System ...

  5. 02_HTML5+CSS3详解第五、六天(实战篇之HTML5制作企业网站)

    [废话连篇 - 实战篇,没什么好说的,最后一章兼容性问题懒得看了,over] Details 一.Xmind部分 xmind教程:http://www.jianshu.com/p/7c488d5e4b ...

  6. Jenkins插件安装实战篇

    Jenkins插件安装实战篇 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 上篇博客我介绍了Jenkins是啥,以及持续集成,持续交付,持续部署的概念,那么问题来了:你知道CI和C ...

  7. 《黑客攻防技术宝典Web实战篇@第2版》读书笔记1:了解Web应用程序

    读书笔记第一部分对应原书的第一章,主要介绍了Web应用程序的发展,功能,安全状况. Web应用程序的发展历程 早期的万维网仅由Web站点构成,只是包含静态文档的信息库,随后人们发明了Web浏览器用来检 ...

  8. Python和Flask真强大:不能错过的15篇技术热文(转载)

    Python和Flask真强大:不能错过的15篇技术热文 本文精选了 Python开发者 11月份的15篇 Python 热文.其中有基础知识,机器学习,爬虫项目实战等. 注:以下文章,点击标题即可阅 ...

  9. 金蝶随手记团队分享:还在用JSON? Protobuf让数据传输更省更快(实战篇)

    本文作者:丁同舟,来自金蝶随手记技术团队. 1.前言 本文接上篇<金蝶随手记团队分享:还在用JSON? Protobuf让数据传输更省更快(原理篇)>,以iOS端的Objective-C代 ...

  10. java并发系列 - 第28天:实战篇,微服务日志的伤痛,一并帮你解决掉

    这是java高并发系列第28篇文章. 环境:jdk1.8. 本文内容 日志有什么用? 日志存在的痛点? 构建日志系统 日志有什么用? 系统出现故障的时候,可以通过日志信息快速定位问题,修复bug,恢复 ...

随机推荐

  1. Safari 14.0 的功臣 Webp?

    俗话说:一图胜千言.在网上,图片虽然可以让用户更加简单明了地看到更多信息,但是图片体积也可以抵过上千字节甚至更多.研究表明,打开一个 HTTP 网页,其中图片平均占比为 64%.在图片占比如此高的情况 ...

  2. three.js 火焰效果

    代码是网上找的代码,通过调参.修改.封装实现的. 代码: /** * 火焰 */ import * as THREE from '../build/three.module.js'; let MyFi ...

  3. Markdown 语法:高级技巧

    Markdown 高级技巧 支持的 HTML 元素 不在 Markdown 涵盖范围之内的标签,多可以直接在文档里面用 HTML 撰写. 目前支持的 HTML 标签有 <kbd>,< ...

  4. Codeforce 1288C. Two Arrays(DP组合数学,n个数选择m个数,单调不递减个数,排列组合打表N*N)

    https://codeforces.com/problemset/problem/1288/C Examples input 2 2 output 5 input 10 1 output 55 in ...

  5. BZOJ 2038: [2009国家集训队]小Z的袜子【莫队算法裸题】

    作为一个生活散漫的人,小Z每天早上都要耗费很久从一堆五颜六色的袜子中找出一双来穿. 终于有一天,小Z再也无法忍受这恼人的找袜子过程,于是他决定听天由命. 具体来说,小Z把这N只袜子从1到N编号,然后从 ...

  6. FZU 2232

    ***题意:求最大匹配是否为n 今天突然想起来吧模板改一下,然而自己得想法不对,WA了有十多次吧,看了一下题解,不需要改,套上模板就行*** #include<stdio.h> #incl ...

  7. python之configparser类的使用

    一.定义配置文件格式如下:data.conf [interface] url=http://192.168.37.8:7777/api/mytest2 [switch] switch_car=on [ ...

  8. MySQL本地服务器与MySQL57网络服务器区别

    MySQL服务器与MySQL57服务器区别与不同处在哪里,他们各自的领域范围,能不能同时启动服务? 安装了MySQL-5.7.18.0版本数据库,版本中包含了MySQL Workbench可视化试图工 ...

  9. 2. 成功使用SQL Plus完成连接,但在使用Oracle SQL Developer连接时,发生报错ORA-12526: TNS:listener: all appropriate instances are in restricted mode

    经了解后得知,错误原因:ORA-12526: TNS: 监听程序: 所有适用例程都处于受限模式. 解决办法:使用系统管理员身份运行以下一段代码 ALTER SYSTEM DISABLE RESTRIC ...

  10. SQL函数——时间函数

    1.使用 NOW() . CURDATE().CURTIME() 获取当前时间 在这里我有一个问题想问问大家,你们平时都是怎么样子获取时间的呢?是不是通过手表.手机.电脑等设备了解到的,那么你们有没有 ...