作者|huggingface

编译|VK

来源|Github

本章介绍使用Transformers库时最常见的用例。可用的模型允许许多不同的配置,并且在用例中具有很强的通用性。这里介绍了最简单的方法,展示了诸如问答、序列分类、命名实体识别等任务的用法。

这些示例利用Auto Model,这些类将根据给定的checkpoint实例化模型,并自动选择正确的模型体系结构。有关详细信息,请查看:AutoModel文档。请随意修改代码,使其更具体,并使其适应你的特定用例。

  • 为了使模型能够在任务上良好地执行,必须从与该任务对应的checkpoint加载模型。这些checkpoint通常是在大量数据上预先训练的,并针对特定任务进行微调。这意味着:并非所有模型都针对所有任务进行了微调。如果要对特定任务的模型进行微调,可以利用examples目录中的run\$task.py脚本。
  • 微调模型是在特定的数据集上微调的。此数据集可能与你的用例和域重叠,也可能不重叠。如前所述,你可以利用示例脚本来微调模型,也可以创建自己的训练脚本。

为了对任务进行推理,库提供了几种机制:

  • 管道是非常易于使用的抽象,只需要两行代码。
  • 直接将模型与Tokenizer(PyTorch/TensorFlow)结合使用来使用模型的完整推理。这种机制稍微复杂,但是更强大。

这里展示了两种方法。

请注意,这里介绍的所有任务都利用了在预训练模型针对特定任务进行微调后的模型。加载未针对特定任务进行微调的checkpoint时,将只加载transformer层,而不会加载用于该任务的附加层,从而随机初始化该附加层的权重。这将产生随机输出。

序列分类

序列分类是根据已经给定的类别然后对序列进行分类的任务。序列分类的一个例子是GLUE数据集,它就是完全基于该任务的。如果你想在GLUE序列分类任务上微调模型,可以利用run_GLUE.pyrun_tf_GLUE.py脚本。

下面是一个使用管道进行情绪分析的例子:识别该序列是积极的还是消极的。它利用sst2上的微调模型,这是一个GLUE任务。

from transformers import pipeline

nlp = pipeline("sentiment-analysis")

print(nlp("I hate you"))
print(nlp("I love you"))

这将返回一个标签(“积极”或“消极”)和一个分数,如下所示:

[{'label': 'NEGATIVE', 'score': 0.9991129}]
[{'label': 'POSITIVE', 'score': 0.99986565}]

下面是一个使用模型进行序列分类的示例,以确定两个序列是否是彼此的解释。该过程如下:

  • 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
  • 从这两句话中构建一个序列,使用正确的特定于模型的分隔符标记类型id和注意力掩码(encode()和encode_plus()处理这个问题)
  • 将这个序列传递到模型中,以便将其分类到两个可用的类中的一个:0(不是解释)和1(是解释)
  • 计算结果的softmax获取类的概率
  • 打印结果

Pytorch代码

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc") classes = ["not paraphrase", "is paraphrase"] sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan" paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt") paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0] paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0] print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%") print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")

TensorFlow代码

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc") classes = ["not paraphrase", "is paraphrase"] sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan" paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="tf")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="tf") paraphrase_classification_logits = model(paraphrase)[0]
not_paraphrase_classification_logits = model(not_paraphrase)[0] paraphrase_results = tf.nn.softmax(paraphrase_classification_logits, axis=1).numpy()[0]
not_paraphrase_results = tf.nn.softmax(not_paraphrase_classification_logits, axis=1).numpy()[0] print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%") print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")

这将输出以下结果:

Should be paraphrase
not paraphrase: 10%
is paraphrase: 90% Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%

抽取式问答

抽取式问答是从给定问题的文本中抽取答案的任务。问答数据集的一个例子是SQuAD数据集,它完全基于该任务。如果你想在团队任务中微调模型,可以利用run_SQuAD.py

下面是一个使用管道进行问答的示例:从给定问题的文本中提取答案。它利用了一个小队的微调模型。

from transformers import pipeline

nlp = pipeline("question-answering")

context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
a model on a SQuAD task, you may leverage the `run_squad.py`.
""" print(nlp(question="What is extractive question answering?", context=context))
print(nlp(question="What is a good example of a question answering dataset?", context=context))

这将返回从文本中提取的答案,一个置信度,以及“开始”和“结束”值,这些值是提取的答案在文本中的位置。

{'score': 0.622232091629833, 'start': 34, 'end': 96, 'answer': 'the task of extracting an answer from a text given a question.'}
{'score': 0.5115299158662765, 'start': 147, 'end': 161, 'answer': 'SQuAD dataset,'}

下面是一个使用模型和Tokenizer回答问题的示例。该过程如下:

  • 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
  • 定义一段文本和几个问题。
  • 遍历问题并根据文本和当前问题构建一个序列,使用正确的模型特定分隔符标记类型id和注意力掩码将此序列传递到模型中。这将输出整个序列标记(问题和文本)的开始位置和结束位置的一系列分数。
  • 计算结果的softmax以获得从标记的开始位置和停止位置对应的概率
  • 将这些标记转换为字符串。
  • 打印结果

Pytorch代码

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad") text = r"""

Transformers 库常见的用例 | 三的更多相关文章

  1. Java避坑宝典《Java业务开发常见错误100例》上线了

    写这个专栏的缘起 之前我写过一篇博客:<朱晔的互联网架构实践心得S2E2:写业务代码最容易掉的10种坑>,引起的关注还是挺多的.后来和极客时间的编辑一拍即合决定以这个为题写一个专栏.其实所 ...

  2. shell常见脚本30例

    shell常见脚本30例 author:headsen chen  2017-10-19  10:12:12 本文原素材出自网上,特此申明.有些地方加入我自己的改动 常见的30例shell脚本 1.用 ...

  3. {MySQL的库、表的详细操作}一 库操作 二 表操作 三 行操作

    MySQL的库.表的详细操作 MySQL数据库 本节目录 一 库操作 二 表操作 三 行操作 一 库操作 1.创建数据库 1.1 语法 CREATE DATABASE 数据库名 charset utf ...

  4. 常见MIME类型例表

    常见MIME类型例表: 序号 内容类型 文件扩展名 描述 1 application/msword doc Microsoft Word 2 application/octet-stream bin ...

  5. 常见的装包的三种宝,包 bao-devel bao-utils bao-agent ,包 开发包 工具包 客户端

    常见的装包的三种宝,包  bao-devel    bao-utils   bao-agent  ,包    开发包   工具包  客户端

  6. VC中引用第三方库,常见的库冲突问题

    Q:VC中引用第三方库,常见的库冲突问题 环境:[1]VS2008 [2]WinXP SP3 A1(方法一): [S1]第三方库(Binary形式的)如果同主程序冲突,则下载第三方库的源码[S2]保持 ...

  7. TCP常见的定时器及三次握手与四次挥手

    1.TCP常见的定时器 在TCP协议中有的时候需要定期或者按照某个算法对某个事件进行触发,那么这个时候,TCP协议是使用定时器进行实现的.在TCP中,会有七种定时器: 建立连接定时器(connecti ...

  8. [WPF自定义控件库]以Button为例谈谈如何模仿Aero2主题

    1. 为什么选择Aero2 除了以外观为卖点的控件库,WPF的控件库都默认使用"素颜"的外观,然后再提供一些主题包.这样做的最大好处是可以和原生控件或其它控件库兼容,而且对于大部分 ...

  9. 关于阿里图标库Iconfont生成图标的三种使用方式(fontclass/unicode/symbol)

    1.附阿里图标库链接:http://www.iconfont.cn/ 2.登录阿里图标库以后,搜索我们需要的图标,将其加入购物车,如图3.将我们需要的图标全部挑选完毕以后,点击购物车图标4.这时候右侧 ...

随机推荐

  1. 小程序在ios10.2系统上兼容

    1.  定位元素在ios10.2系统上出现样式问题??? 没错,就是在测试在侧道ios10.2系统时发现了样式错误的问题,比如一个Swiper中,最后一个展示有问题. 这是啥原因❓❓❓❓❓❓ 大写的问 ...

  2. sql服务器第5级事务日志管理的阶梯:完全恢复模式下的日志管理

    sql服务器第5级事务日志管理的阶梯:完全恢复模式下的日志管理 原文链接http://www.sqlservercentral.com/articles/Stairway+Series/73785/ ...

  3. EOS2.0环境搭建-centos7

    需要安装启动的有三个组件 nodes,keosd,cleos,看看三者的关系 nodeos:核心程序,用于启动eos节点服务,在后台运行,可以配置不同 插件.该进程负责账户管理.区块生成.共识建立,并 ...

  4. VUE实现Studio管理后台(二):Slot实现选项卡tab切换效果,可自由填装内容

    作为RXEditor的主界面,Studio UI要使用大量的选项卡TAB切换,我梦想的TAB切换是可以自由填充内容的.可惜自己不会实现,只好在网上搜索一下,就跟现在你做的一样,看看有没有好事者实现了类 ...

  5. Springboot 2.2.x 默认不支持put、delete等请求方式

    springboot 2.2.x 默认不支持put delete等请 原因:springboot默认关闭了对它们的支持,只要在application.properties里面打开即可 spring.m ...

  6. 用table类型布局一个新闻网页

    <html><head><meta http-equiv="Content-Type" content="text/html; charse ...

  7. 数据挖掘入门系列教程(二)之分类问题OneR算法

    数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...

  8. git删除远程仓库中的文件夹

    具体操作如下: git rm -r --cached .history    #删除目录 git commit -m”删除.history文件夹” git push -r表示递归所有子目录,如果你要删 ...

  9. python网络协议

    一 互联网的本质 咱们先不说互联网是如何通信的(发送数据,文件等),先用一个经典的例子,给大家说明什么是互联网通信. 现在追溯到八九十年代,当时电话刚刚兴起,还没有手机的概念,只是有线电话,那么此时你 ...

  10. Python之locust踩坑指北

    坑点1:locust安装报错 其中一种情况:error:Microsoft Visual C++ 14.0 is required. Get it with "Microsoft Visua ...