Transformers 库常见的用例 | 三
作者|huggingface
编译|VK
来源|Github
本章介绍使用Transformers库时最常见的用例。可用的模型允许许多不同的配置,并且在用例中具有很强的通用性。这里介绍了最简单的方法,展示了诸如问答、序列分类、命名实体识别等任务的用法。
这些示例利用Auto Model
,这些类将根据给定的checkpoint实例化模型,并自动选择正确的模型体系结构。有关详细信息,请查看:AutoModel
文档。请随意修改代码,使其更具体,并使其适应你的特定用例。
- 为了使模型能够在任务上良好地执行,必须从与该任务对应的checkpoint加载模型。这些checkpoint通常是在大量数据上预先训练的,并针对特定任务进行微调。这意味着:并非所有模型都针对所有任务进行了微调。如果要对特定任务的模型进行微调,可以利用examples目录中的
run\$task.py
脚本。 - 微调模型是在特定的数据集上微调的。此数据集可能与你的用例和域重叠,也可能不重叠。如前所述,你可以利用示例脚本来微调模型,也可以创建自己的训练脚本。
为了对任务进行推理,库提供了几种机制:
- 管道是非常易于使用的抽象,只需要两行代码。
- 直接将模型与Tokenizer(PyTorch/TensorFlow)结合使用来使用模型的完整推理。这种机制稍微复杂,但是更强大。
这里展示了两种方法。
请注意,这里介绍的所有任务都利用了在预训练模型针对特定任务进行微调后的模型。加载未针对特定任务进行微调的checkpoint时,将只加载transformer层,而不会加载用于该任务的附加层,从而随机初始化该附加层的权重。这将产生随机输出。
序列分类
序列分类是根据已经给定的类别然后对序列进行分类的任务。序列分类的一个例子是GLUE数据集,它就是完全基于该任务的。如果你想在GLUE序列分类任务上微调模型,可以利用run_GLUE.py
或run_tf_GLUE.py
脚本。
下面是一个使用管道进行情绪分析的例子:识别该序列是积极的还是消极的。它利用sst2上的微调模型,这是一个GLUE任务。
from transformers import pipeline
nlp = pipeline("sentiment-analysis")
print(nlp("I hate you"))
print(nlp("I love you"))
这将返回一个标签(“积极”或“消极”)和一个分数,如下所示:
[{'label': 'NEGATIVE', 'score': 0.9991129}]
[{'label': 'POSITIVE', 'score': 0.99986565}]
下面是一个使用模型进行序列分类的示例,以确定两个序列是否是彼此的解释。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 从这两句话中构建一个序列,使用正确的特定于模型的分隔符标记类型id和注意力掩码(encode()和encode_plus()处理这个问题)
- 将这个序列传递到模型中,以便将其分类到两个可用的类中的一个:0(不是解释)和1(是解释)
- 计算结果的softmax获取类的概率
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")
paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
TensorFlow代码
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="tf")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="tf")
paraphrase_classification_logits = model(paraphrase)[0]
not_paraphrase_classification_logits = model(not_paraphrase)[0]
paraphrase_results = tf.nn.softmax(paraphrase_classification_logits, axis=1).numpy()[0]
not_paraphrase_results = tf.nn.softmax(not_paraphrase_classification_logits, axis=1).numpy()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
这将输出以下结果:
Should be paraphrase
not paraphrase: 10%
is paraphrase: 90%
Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%
抽取式问答
抽取式问答是从给定问题的文本中抽取答案的任务。问答数据集的一个例子是SQuAD数据集,它完全基于该任务。如果你想在团队任务中微调模型,可以利用run_SQuAD.py
。
下面是一个使用管道进行问答的示例:从给定问题的文本中提取答案。它利用了一个小队的微调模型。
from transformers import pipeline
nlp = pipeline("question-answering")
context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
a model on a SQuAD task, you may leverage the `run_squad.py`.
"""
print(nlp(question="What is extractive question answering?", context=context))
print(nlp(question="What is a good example of a question answering dataset?", context=context))
这将返回从文本中提取的答案,一个置信度,以及“开始”和“结束”值,这些值是提取的答案在文本中的位置。
{'score': 0.622232091629833, 'start': 34, 'end': 96, 'answer': 'the task of extracting an answer from a text given a question.'}
{'score': 0.5115299158662765, 'start': 147, 'end': 161, 'answer': 'SQuAD dataset,'}
下面是一个使用模型和Tokenizer回答问题的示例。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 定义一段文本和几个问题。
- 遍历问题并根据文本和当前问题构建一个序列,使用正确的模型特定分隔符标记类型id和注意力掩码将此序列传递到模型中。这将输出整个序列标记(问题和文本)的开始位置和结束位置的一系列分数。
- 计算结果的softmax以获得从标记的开始位置和停止位置对应的概率
- 将这些标记转换为字符串。
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
text = r"""
Transformers 库常见的用例 | 三的更多相关文章
写这个专栏的缘起 之前我写过一篇博客:<朱晔的互联网架构实践心得S2E2:写业务代码最容易掉的10种坑>,引起的关注还是挺多的.后来和极客时间的编辑一拍即合决定以这个为题写一个专栏.其实所 ...
shell常见脚本30例 author:headsen chen 2017-10-19 10:12:12 本文原素材出自网上,特此申明.有些地方加入我自己的改动 常见的30例shell脚本 1.用 ...
MySQL的库.表的详细操作 MySQL数据库 本节目录 一 库操作 二 表操作 三 行操作 一 库操作 1.创建数据库 1.1 语法 CREATE DATABASE 数据库名 charset utf ...
常见MIME类型例表: 序号 内容类型 文件扩展名 描述 1 application/msword doc Microsoft Word 2 application/octet-stream bin ...
常见的装包的三种宝,包 bao-devel bao-utils bao-agent ,包 开发包 工具包 客户端
Q:VC中引用第三方库,常见的库冲突问题 环境:[1]VS2008 [2]WinXP SP3 A1(方法一): [S1]第三方库(Binary形式的)如果同主程序冲突,则下载第三方库的源码[S2]保持 ...
1.TCP常见的定时器 在TCP协议中有的时候需要定期或者按照某个算法对某个事件进行触发,那么这个时候,TCP协议是使用定时器进行实现的.在TCP中,会有七种定时器: 建立连接定时器(connecti ...
1. 为什么选择Aero2 除了以外观为卖点的控件库,WPF的控件库都默认使用"素颜"的外观,然后再提供一些主题包.这样做的最大好处是可以和原生控件或其它控件库兼容,而且对于大部分 ...
1.附阿里图标库链接:http://www.iconfont.cn/ 2.登录阿里图标库以后,搜索我们需要的图标,将其加入购物车,如图3.将我们需要的图标全部挑选完毕以后,点击购物车图标4.这时候右侧 ...
随机推荐
本文主要记录CentOS下FTP Server的安装和配置流程. 安装vsftpd yum install -y vsftpd 启动vsftpd service vsftpd start 运行下面的命 ...
在使用echart过程中,toolbox里有个dataView视图模式,里面的数据没有对整,影响展示效果,情形如下:改问题解决方案为,在optionTocontent回调函数中处理,具体代码如下: t ...
一.整体大纲 二.基础知识 1. 进程相关概念 1)程序和进程 程序,是指编译好的二进制文件,在磁盘上,不占用系统资源(cpu.内存.打开的文件.设备.锁....) 进程,是一个抽象的概念,与 ...
经过这么多年的发展,JavaScript 早已经不是当年那个不太起眼的脚本语言.如今的 JavaScript 可以说是风光无限,在 Web 前端.移动端.服务端甚至物联网设备上都大展身手,到处都有它的 ...
Specifications动态查询 有时我们在查询某个实体的时候,给定的条件是不固定的,这时就需要动态构建相应的查询语句,在Spring Data JPA中可以通过JpaSpecificationE ...
SQL Injection SQL注入 Abstract 通过不可信来源的输入构建动态 SQL 指令,攻击者就能够修改指令的含义或者执行任意 SQL 命令. Explanation SQL injec ...
git密令是一种非常好用的代码版本管理工具,相比SVN,Sourcetree 使用起来复杂,主要是没有汉化包,当你使用熟练时,其实也是非常简单的,逼格高. 具体使用如下: 情景一:你只有远程库,没有本 ...
目录 Vue2.0 [第二季]第8节 Component 父子组件关系 第8节 Component 父子组件关系 一.构造器外部写局部注册组件 二.父子组件的嵌套 Vue2.0 [第二季]第8节 Co ...
index = ~~this.userIndex ~~ 双破折号 如果是数字返回数字,如果不是数字 返回0 这个运算符有点意思:按位非[~] 先来几个例子: ~undefined: -1 ~false ...
第一步:用户输入网址进入一个登陆界面. 里面要有账号密码输入. 登陆界面链接到登陆的Servlet类中. Servlet类 --> 1.接收参数(账户密码) 2.调用DAO层的 SQL语句 验 ...