Transformers 库常见的用例 | 三
作者|huggingface
编译|VK
来源|Github
本章介绍使用Transformers库时最常见的用例。可用的模型允许许多不同的配置,并且在用例中具有很强的通用性。这里介绍了最简单的方法,展示了诸如问答、序列分类、命名实体识别等任务的用法。
这些示例利用Auto Model
,这些类将根据给定的checkpoint实例化模型,并自动选择正确的模型体系结构。有关详细信息,请查看:AutoModel
文档。请随意修改代码,使其更具体,并使其适应你的特定用例。
- 为了使模型能够在任务上良好地执行,必须从与该任务对应的checkpoint加载模型。这些checkpoint通常是在大量数据上预先训练的,并针对特定任务进行微调。这意味着:并非所有模型都针对所有任务进行了微调。如果要对特定任务的模型进行微调,可以利用examples目录中的
run\$task.py
脚本。 - 微调模型是在特定的数据集上微调的。此数据集可能与你的用例和域重叠,也可能不重叠。如前所述,你可以利用示例脚本来微调模型,也可以创建自己的训练脚本。
为了对任务进行推理,库提供了几种机制:
- 管道是非常易于使用的抽象,只需要两行代码。
- 直接将模型与Tokenizer(PyTorch/TensorFlow)结合使用来使用模型的完整推理。这种机制稍微复杂,但是更强大。
这里展示了两种方法。
请注意,这里介绍的所有任务都利用了在预训练模型针对特定任务进行微调后的模型。加载未针对特定任务进行微调的checkpoint时,将只加载transformer层,而不会加载用于该任务的附加层,从而随机初始化该附加层的权重。这将产生随机输出。
序列分类
序列分类是根据已经给定的类别然后对序列进行分类的任务。序列分类的一个例子是GLUE数据集,它就是完全基于该任务的。如果你想在GLUE序列分类任务上微调模型,可以利用run_GLUE.py
或run_tf_GLUE.py
脚本。
下面是一个使用管道进行情绪分析的例子:识别该序列是积极的还是消极的。它利用sst2上的微调模型,这是一个GLUE任务。
from transformers import pipeline
nlp = pipeline("sentiment-analysis")
print(nlp("I hate you"))
print(nlp("I love you"))
这将返回一个标签(“积极”或“消极”)和一个分数,如下所示:
[{'label': 'NEGATIVE', 'score': 0.9991129}]
[{'label': 'POSITIVE', 'score': 0.99986565}]
下面是一个使用模型进行序列分类的示例,以确定两个序列是否是彼此的解释。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 从这两句话中构建一个序列,使用正确的特定于模型的分隔符标记类型id和注意力掩码(encode()和encode_plus()处理这个问题)
- 将这个序列传递到模型中,以便将其分类到两个可用的类中的一个:0(不是解释)和1(是解释)
- 计算结果的softmax获取类的概率
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")
paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
TensorFlow代码
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="tf")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="tf")
paraphrase_classification_logits = model(paraphrase)[0]
not_paraphrase_classification_logits = model(not_paraphrase)[0]
paraphrase_results = tf.nn.softmax(paraphrase_classification_logits, axis=1).numpy()[0]
not_paraphrase_results = tf.nn.softmax(not_paraphrase_classification_logits, axis=1).numpy()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
这将输出以下结果:
Should be paraphrase
not paraphrase: 10%
is paraphrase: 90%
Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%
抽取式问答
抽取式问答是从给定问题的文本中抽取答案的任务。问答数据集的一个例子是SQuAD数据集,它完全基于该任务。如果你想在团队任务中微调模型,可以利用run_SQuAD.py
。
下面是一个使用管道进行问答的示例:从给定问题的文本中提取答案。它利用了一个小队的微调模型。
from transformers import pipeline
nlp = pipeline("question-answering")
context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
a model on a SQuAD task, you may leverage the `run_squad.py`.
"""
print(nlp(question="What is extractive question answering?", context=context))
print(nlp(question="What is a good example of a question answering dataset?", context=context))
这将返回从文本中提取的答案,一个置信度,以及“开始”和“结束”值,这些值是提取的答案在文本中的位置。
{'score': 0.622232091629833, 'start': 34, 'end': 96, 'answer': 'the task of extracting an answer from a text given a question.'}
{'score': 0.5115299158662765, 'start': 147, 'end': 161, 'answer': 'SQuAD dataset,'}
下面是一个使用模型和Tokenizer回答问题的示例。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 定义一段文本和几个问题。
- 遍历问题并根据文本和当前问题构建一个序列,使用正确的模型特定分隔符标记类型id和注意力掩码将此序列传递到模型中。这将输出整个序列标记(问题和文本)的开始位置和结束位置的一系列分数。
- 计算结果的softmax以获得从标记的开始位置和停止位置对应的概率
- 将这些标记转换为字符串。
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
text = r"""
Transformers 库常见的用例 | 三的更多相关文章
- Java避坑宝典《Java业务开发常见错误100例》上线了
写这个专栏的缘起 之前我写过一篇博客:<朱晔的互联网架构实践心得S2E2:写业务代码最容易掉的10种坑>,引起的关注还是挺多的.后来和极客时间的编辑一拍即合决定以这个为题写一个专栏.其实所 ...
- shell常见脚本30例
shell常见脚本30例 author:headsen chen 2017-10-19 10:12:12 本文原素材出自网上,特此申明.有些地方加入我自己的改动 常见的30例shell脚本 1.用 ...
- {MySQL的库、表的详细操作}一 库操作 二 表操作 三 行操作
MySQL的库.表的详细操作 MySQL数据库 本节目录 一 库操作 二 表操作 三 行操作 一 库操作 1.创建数据库 1.1 语法 CREATE DATABASE 数据库名 charset utf ...
- 常见MIME类型例表
常见MIME类型例表: 序号 内容类型 文件扩展名 描述 1 application/msword doc Microsoft Word 2 application/octet-stream bin ...
- 常见的装包的三种宝,包 bao-devel bao-utils bao-agent ,包 开发包 工具包 客户端
常见的装包的三种宝,包 bao-devel bao-utils bao-agent ,包 开发包 工具包 客户端
- VC中引用第三方库,常见的库冲突问题
Q:VC中引用第三方库,常见的库冲突问题 环境:[1]VS2008 [2]WinXP SP3 A1(方法一): [S1]第三方库(Binary形式的)如果同主程序冲突,则下载第三方库的源码[S2]保持 ...
- TCP常见的定时器及三次握手与四次挥手
1.TCP常见的定时器 在TCP协议中有的时候需要定期或者按照某个算法对某个事件进行触发,那么这个时候,TCP协议是使用定时器进行实现的.在TCP中,会有七种定时器: 建立连接定时器(connecti ...
- [WPF自定义控件库]以Button为例谈谈如何模仿Aero2主题
1. 为什么选择Aero2 除了以外观为卖点的控件库,WPF的控件库都默认使用"素颜"的外观,然后再提供一些主题包.这样做的最大好处是可以和原生控件或其它控件库兼容,而且对于大部分 ...
- 关于阿里图标库Iconfont生成图标的三种使用方式(fontclass/unicode/symbol)
1.附阿里图标库链接:http://www.iconfont.cn/ 2.登录阿里图标库以后,搜索我们需要的图标,将其加入购物车,如图3.将我们需要的图标全部挑选完毕以后,点击购物车图标4.这时候右侧 ...
随机推荐
- NVARCHAR(MAX) 的最大长度
本文使用的环境是SQL Server 2017, 主机是64位操作系统. 大家都知道,Micorosoft Docs对 max参数的定义是:max 指定最大的存储空间是2GB,该注释是不严谨的: nv ...
- C++走向远洋——36(数组做数据成员,工资)
*/ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:salarly.cpp * 作者:常轩 * 微信公众号:Worl ...
- [2020.03]Unity ML-Agents v0.15.0 环境部署与试运行
一.ML-Agents简介 近期在学习Unity中的机器学习插件ML-Agents,做一些记录,用以简单记录或交流学习. 先简单说一下机器学习使用的环境场景:高视觉复杂度(Visual Complex ...
- 【读后感】《Java编程思想》~ 异常
[读后感]<Java编程思想>~异常 终于拿出压箱底的那本<Java编程思想>.这本书我年轻的时候就买了,但是翻过几页后就放弃了.没想到这两天翻了一下,真的有收获. 看了一下第 ...
- Kubernetes Jenkins动态创建Slave
目录 0.前言 1.Jenkins部署 2.配置jenkins动态slave 3.dubbo服务构建 3.1.制作dubbo镜像底包 3.2.制作slave基础镜像 3.2.1.Maven镜像 3.2 ...
- Python进阶练习与爬取豆瓣T250的影片相关信息
(一)Python进阶练习 正所谓要将知识进行实践,才会真正的掌握 于是就练习了几道题:求素数,求奇数,求九九乘法表,字符串练习 import re #求素数 i=1; flag=0 while(i& ...
- 【面试经验分享】java面试中的那些潜规则
1.大纲 潜规则1:面试的本质不是考试,而是告诉面试官你会做什么 很多刚入行的小伙伴特别容易犯的一个错误,不清楚面试官到底想问什么,其实整个面试中面试官并没有想难道你的意思,只是想通过提问的方式来知道 ...
- xpath提取标签和内容
转:https://segmentfault.com/q/1010000012110138/a-1020000012113020 <div> <table> <tr> ...
- Python自定义模块
自定义模块 自定义模块(也就是私人订制),我们要自定义模块,首先就要知道什么是模块 一个函数封装一个功能,比如现在有一个软件,不可能将所有程序都写入一个文件,所以咱们应该分文件,组织结构要好,代码不冗 ...
- C++ 判断两个圆是否有交集
#define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include <math.h> #include <easyx.h ...