作者|huggingface

编译|VK

来源|Github

理念

Transformers是一个为NLP的研究人员寻求使用/研究/扩展大型Transformers模型的库。

该库的设计有两个强烈的目标:

  • 尽可能简单和快速使用:

    • 我们尽可能限制了要学习的面向对象抽象的类的数量,实际上几乎没有抽象,每个模型只需要使用三个标准类:配置、模型和tokenizer,
    • 所有这些类都可以通过使用公共的from_pretrained()实例化方法从预训练实例以简单统一的方式初始化,该方法将负责从库中下载,缓存和加载相关类提供的预训练模型或你自己保存的模型。
    • 因此,这个库不是构建神经网络模块的工具箱。如果您想扩展/构建这个库,只需使用常规的Python/PyTorch模块,并从这个库的基类继承,以重用诸如模型加载/保存等功能。
  • 提供最先进的模型与性能尽可能接近的原始模型:
    • 我们为每个架构提供了至少一个例子,该例子再现了上述架构的官方作者提供的结果
    • 代码通常尽可能地接近原始代码,这意味着一些PyTorch代码可能不那么pytorch化,因为这是转换TensorFlow代码后的结果。

其他几个目标:

  • 尽可能一致地暴露模型的内部:

    • 我们使用一个API来访问所有的隐藏状态和注意力权重,
    • 对tokenizer和基本模型的API进行了标准化,以方便在模型之间进行切换。
  • 结合一个主观选择的有前途的工具微调/调查这些模型:
    • 向词汇表和嵌入项添加新标记以进行微调的简单/一致的方法,
    • 简单的方法面具和修剪变压器头。

主要概念

该库是建立在三个类型的类为每个模型:

  • model类是目前在库中提供的8个模型架构的PyTorch模型(torch.nn.Modules),例如BertModel
  • configuration类,它存储构建模型所需的所有参数,例如BertConfig。您不必总是自己实例化这些配置,特别是如果您使用的是未经任何修改的预训练的模型,创建模型将自动负责实例化配置(它是模型的一部分)
  • tokenizer类,它存储每个模型的词汇表,并在要输送到模型的词汇嵌入索引列表中提供用于编码/解码字符串的方法,例如BertTokenizer

所有这些类都可以从预训练模型来实例化,并使用两种方法在本地保存:

  • from_pretraining()允许您从一个预训练版本实例化一个模型/配置/tokenizer,这个预训练版本可以由库本身提供(目前这里列出了27个模型),也可以由用户在本地(或服务器上)存储,
  • save_pretraining()允许您在本地保存模型/配置/tokenizer,以便可以使用from_pretraining()重新加载它。

我们将通过一些简单的快速启动示例来完成这个快速启动之旅,看看如何实例化和使用这些类。其余的文件分为两部分:

  • 主要的类详细介绍了三种主要类(配置、模型、tokenizer)的公共功能/方法/属性,以及一些作为训练工具提供的优化类,
  • 包引用部分详细描述了每个模型体系结构的每个类的所有变体,特别是调用它们时它们期望的输入和输出。

快速入门:使用

这里有两个例子展示了一些Bert和GPT2类以及预训练模型。

有关每个模型类的示例,请参阅完整的API参考。

BERT示例

让我们首先使用BertTokenizer从文本字符串准备一个标记化的输入(要输入给BERT的标记嵌入索引列表)

import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练的模型标记器(词汇表)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 标记输入
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text) # 用“BertForMaskedLM”掩盖我们试图预测的标记`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]'] # 将标记转换为词汇索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 定义与第一句和第二句相关的句子A和B索引(见论文)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

让我们看看如何使用BertModel在隐藏状态下对输入进行编码:

# 加载预训练模型(权重)
model = BertModel.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') #预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
# Transformer模型总是输出元组。
# 有关所有输出的详细信息,请参见模型文档字符串。在我们的例子中,第一个元素是Bert模型最后一层的隐藏状态
encoded_layers = outputs[0]
# 我们已将输入序列编码为形状(批量大小、序列长度、模型隐藏维度)的FloatTensor
assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)

以及如何使用BertForMaskedLM预测屏蔽的标记:

# 加载预训练模型(权重)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0] # 确认我们能预测“henson”
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

OpenAI GPT-2

下面是一个快速开始的例子,使用GPT2TokenizerGPT2LMHeadModel类以及OpenAI的预训练模型来预测文本提示中的下一个标记。

首先,让我们使用GPT2Tokenizer

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # 编码输入
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text) # 转换为PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])

让我们看看如何使用GPT2LMHeadModel生成下一个跟在我们的文本后面的token:

# 加载预训练模型(权重)
model = GPT2LMHeadModel.from_pretrained('gpt2') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0] # 得到预测的下一个子词(在我们的例子中,是“man”这个词)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'

每个模型架构(Bert、GPT、GPT-2、Transformer XL、XLNet和XLM)的每个模型类的示例,可以在文档中找到。

使用过去的GPT-2

以及其他一些模型(GPT、XLNet、Transfo XL、CTRL),使用pastmems属性,这些属性可用于防止在使用顺序解码时重新计算键/值对。它在生成序列时很有用,因为注意力机制的很大一部分得益于以前的计算。

下面是一个使用带pastGPT2LMHeadModel和argmax解码的完整工作示例(只能作为示例,因为argmax decoding引入了大量重复):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2') generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[..., -1, :]) generated += [token.tolist()]
context = token.unsqueeze(0) sequence = tokenizer.decode(generated) print(sequence)

由于以前所有标记的键/值对都包含在past,因此模型只需要一个标记作为输入。

Model2Model示例

编码器-解码器架构需要两个标记化输入:一个用于编码器,另一个用于解码器。假设我们想使用Model2Model进行生成性问答,从标记将输入模型的问答开始。

import torch
from transformers import BertTokenizer, Model2Model # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 编码输入(问题)
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question) # 编码输入(答案)
answer = "Jim Henson was a puppeteer"
encoded_answer = tokenizer.encode(answer) # 将输入转换为PyTorch张量
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])

让我们看看如何使用Model2Model获取与此(问题,答案)对相关联的loss值:

#为了计算损失,我们需要向解码器提供语言模型标签(模型生成的标记id)。
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
labels_tensor = labels_tensor.to('cuda')
model.to('cuda') # 预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(question_tensor, answer_tensor, decoder_lm_labels=labels_tensor)
# Transformers模型总是输出元组。
# 有关所有输出的详细信息,请参见models文档字符串
# 在我们的例子中,第一个元素是LM损失的值
lm_loss = outputs[0]

此损失可用于对Model2Model的问答任务进行微调。假设我们对模型进行了微调,现在让我们看看如何生成答案:

# 让我们重复前面的问题
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question)
question_tensor = torch.tensor([encoded_question]) # 这次我们试图生成答案,所以我们从一个空序列开始
answer = "[CLS]"
encoded_answer = tokenizer.encode(answer, add_special_tokens=False)
answer_tensor = torch.tensor([encoded_answer]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('fine-tuned-weights')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(question_tensor, answer_tensor)
predictions = outputs[0] # 确认我们能预测“jim”
predicted_index = torch.argmax(predictions[0, -1]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'jim'

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

Transformers 快速入门 | 一的更多相关文章

  1. Node.js快速入门

    Node.js是什么? Node.js是建立在谷歌Chrome的JavaScript引擎(V8引擎)的Web应用程序框架. 它的最新版本是:v0.12.7(在编写本教程时的版本).Node.js在官方 ...

  2. Web Api 入门实战 (快速入门+工具使用+不依赖IIS)

    平台之大势何人能挡? 带着你的Net飞奔吧!:http://www.cnblogs.com/dunitian/p/4822808.html 屁话我也就不多说了,什么简介的也省了,直接简单概括+demo ...

  3. SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=》提升)

     SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=>提升,5个Demo贯彻全篇,感兴趣的玩才是真的学) 官方demo:http://www.asp.net/si ...

  4. 前端开发小白必学技能—非关系数据库又像关系数据库的MongoDB快速入门命令(2)

    今天给大家道个歉,没有及时更新MongoDB快速入门的下篇,最近有点小忙,在此向博友们致歉.下面我将简单地说一下mongdb的一些基本命令以及我们日常开发过程中的一些问题.mongodb可以为我们提供 ...

  5. 【第三篇】ASP.NET MVC快速入门之安全策略(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  6. 【番外篇】ASP.NET MVC快速入门之免费jQuery控件库(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  7. Mybatis框架 的快速入门

    MyBatis 简介 什么是 MyBatis? MyBatis 是支持普通 SQL 查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除 了几乎所有的 JDBC 代码和参数的手工设置以及结果 ...

  8. grunt快速入门

    快速入门 Grunt和 Grunt 插件是通过 npm 安装并管理的,npm是 Node.js 的包管理器. Grunt 0.4.x 必须配合Node.js >= 0.8.0版本使用.:奇数版本 ...

  9. 【第一篇】ASP.NET MVC快速入门之数据库操作(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

随机推荐

  1. python爬虫之selenium+打码平台识别验证码

    1.常用的打码平台:超级鹰.打码兔等 2.打码平台在识别图形验证码和点触验证码上比较好用 (1)12306点触验证码 from selenium import webdriver from selen ...

  2. python随用随学-元类

    python中的一切都是对象 按着我的逻辑走: 首先接受一个公理,「python中的一切都是对象」.不要问为什么,吉大爷(Guido van Rossum,python之父)人当初就是这么设计的,不服 ...

  3. 石油测井专题(六)MCM工艺在LWD的应用

    在上一篇的MCM工艺我们提到过石英挠性加速度计的伺服电路采用此工艺可以有效提高仪器产品的稳定性和寿命. MCM相对于印制电路板(PCB)来讲,MCM技术采用了更短的连接长度和更紧密的器件布局,从而降低 ...

  4. Vue双向绑定的实现原理系列(三):监听器Observer和订阅者Watcher

    监听器Observer和订阅者Watcher 实现简单版Vue的过程,主要实现{{}}.v-model和事件指令的功能 主要分为三个部分 github源码 1.数据监听器Observer,能够对数据对 ...

  5. EventEmitter:从命令式 JavaScript class 到声明函数式的华丽转身

    新书终于截稿,今天稍有空闲,为大家奉献一篇关于 JavaScript 语言风格的文章,主角是函数声明式. 灵活的 JavaScript 及其 multiparadigm 相信"函数式&quo ...

  6. C++读入输出优化

    读入输出优化虽然对于小数据没有半点作用,但是对于大数据来说,可以优化几十ms. 有时就是那么几十ms,可以被卡掉大数据的点 读入优化 int read() { int x=0,sig=1; char ...

  7. day05基本运算符,格式化输出,垃圾回收机制

    内容大纲:1.垃圾回收机制详解(了解) 引用计数 标记清除 分代回收 2.与用户交互 接收用户输入 # python3中 input # python2.7(了解) input raw_input 格 ...

  8. PyQt完整入门教程

    1.GUI开发框架简介 19年来,一直在做Android ROM相关测试,也有了一定的积累:20年,计划把之前完整的测试方案.脚本.工具进行整合复用. 第一期计划是开发一个GUI的测试工具,近期也进行 ...

  9. Python卸载

    前言 自己瞎折腾下载Python3.8.2,把之前下载好的python3.7.3覆盖掉.在运行之前Python环境的程序多次未果后.找到原因,Python3.7.3的包不支持Python3.8.2.于 ...

  10. SQL之开窗函数详解--可代替聚合函数使用

    在没学习开窗函数之前,我们都知道,用了分组之后,查询字段就只能是分组字段和聚合的字段,这带来了极大的不方便,有时我们查询时需要分组,又需要查询不分组的字段,每次都要又到子查询,这样显得sql语句复杂难 ...