作者|huggingface

编译|VK

来源|Github

理念

Transformers是一个为NLP的研究人员寻求使用/研究/扩展大型Transformers模型的库。

该库的设计有两个强烈的目标:

  • 尽可能简单和快速使用:

    • 我们尽可能限制了要学习的面向对象抽象的类的数量,实际上几乎没有抽象,每个模型只需要使用三个标准类:配置、模型和tokenizer,
    • 所有这些类都可以通过使用公共的from_pretrained()实例化方法从预训练实例以简单统一的方式初始化,该方法将负责从库中下载,缓存和加载相关类提供的预训练模型或你自己保存的模型。
    • 因此,这个库不是构建神经网络模块的工具箱。如果您想扩展/构建这个库,只需使用常规的Python/PyTorch模块,并从这个库的基类继承,以重用诸如模型加载/保存等功能。
  • 提供最先进的模型与性能尽可能接近的原始模型:
    • 我们为每个架构提供了至少一个例子,该例子再现了上述架构的官方作者提供的结果
    • 代码通常尽可能地接近原始代码,这意味着一些PyTorch代码可能不那么pytorch化,因为这是转换TensorFlow代码后的结果。

其他几个目标:

  • 尽可能一致地暴露模型的内部:

    • 我们使用一个API来访问所有的隐藏状态和注意力权重,
    • 对tokenizer和基本模型的API进行了标准化,以方便在模型之间进行切换。
  • 结合一个主观选择的有前途的工具微调/调查这些模型:
    • 向词汇表和嵌入项添加新标记以进行微调的简单/一致的方法,
    • 简单的方法面具和修剪变压器头。

主要概念

该库是建立在三个类型的类为每个模型:

  • model类是目前在库中提供的8个模型架构的PyTorch模型(torch.nn.Modules),例如BertModel
  • configuration类,它存储构建模型所需的所有参数,例如BertConfig。您不必总是自己实例化这些配置,特别是如果您使用的是未经任何修改的预训练的模型,创建模型将自动负责实例化配置(它是模型的一部分)
  • tokenizer类,它存储每个模型的词汇表,并在要输送到模型的词汇嵌入索引列表中提供用于编码/解码字符串的方法,例如BertTokenizer

所有这些类都可以从预训练模型来实例化,并使用两种方法在本地保存:

  • from_pretraining()允许您从一个预训练版本实例化一个模型/配置/tokenizer,这个预训练版本可以由库本身提供(目前这里列出了27个模型),也可以由用户在本地(或服务器上)存储,
  • save_pretraining()允许您在本地保存模型/配置/tokenizer,以便可以使用from_pretraining()重新加载它。

我们将通过一些简单的快速启动示例来完成这个快速启动之旅,看看如何实例化和使用这些类。其余的文件分为两部分:

  • 主要的类详细介绍了三种主要类(配置、模型、tokenizer)的公共功能/方法/属性,以及一些作为训练工具提供的优化类,
  • 包引用部分详细描述了每个模型体系结构的每个类的所有变体,特别是调用它们时它们期望的输入和输出。

快速入门:使用

这里有两个例子展示了一些Bert和GPT2类以及预训练模型。

有关每个模型类的示例,请参阅完整的API参考。

BERT示例

让我们首先使用BertTokenizer从文本字符串准备一个标记化的输入(要输入给BERT的标记嵌入索引列表)

import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练的模型标记器(词汇表)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 标记输入
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text) # 用“BertForMaskedLM”掩盖我们试图预测的标记`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]'] # 将标记转换为词汇索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 定义与第一句和第二句相关的句子A和B索引(见论文)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

让我们看看如何使用BertModel在隐藏状态下对输入进行编码:

# 加载预训练模型(权重)
model = BertModel.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') #预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
# Transformer模型总是输出元组。
# 有关所有输出的详细信息,请参见模型文档字符串。在我们的例子中,第一个元素是Bert模型最后一层的隐藏状态
encoded_layers = outputs[0]
# 我们已将输入序列编码为形状(批量大小、序列长度、模型隐藏维度)的FloatTensor
assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)

以及如何使用BertForMaskedLM预测屏蔽的标记:

# 加载预训练模型(权重)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0] # 确认我们能预测“henson”
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

OpenAI GPT-2

下面是一个快速开始的例子,使用GPT2TokenizerGPT2LMHeadModel类以及OpenAI的预训练模型来预测文本提示中的下一个标记。

首先,让我们使用GPT2Tokenizer

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # 编码输入
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text) # 转换为PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])

让我们看看如何使用GPT2LMHeadModel生成下一个跟在我们的文本后面的token:

# 加载预训练模型(权重)
model = GPT2LMHeadModel.from_pretrained('gpt2') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0] # 得到预测的下一个子词(在我们的例子中,是“man”这个词)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'

每个模型架构(Bert、GPT、GPT-2、Transformer XL、XLNet和XLM)的每个模型类的示例,可以在文档中找到。

使用过去的GPT-2

以及其他一些模型(GPT、XLNet、Transfo XL、CTRL),使用pastmems属性,这些属性可用于防止在使用顺序解码时重新计算键/值对。它在生成序列时很有用,因为注意力机制的很大一部分得益于以前的计算。

下面是一个使用带pastGPT2LMHeadModel和argmax解码的完整工作示例(只能作为示例,因为argmax decoding引入了大量重复):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2') generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[..., -1, :]) generated += [token.tolist()]
context = token.unsqueeze(0) sequence = tokenizer.decode(generated) print(sequence)

由于以前所有标记的键/值对都包含在past,因此模型只需要一个标记作为输入。

Model2Model示例

编码器-解码器架构需要两个标记化输入:一个用于编码器,另一个用于解码器。假设我们想使用Model2Model进行生成性问答,从标记将输入模型的问答开始。

import torch
from transformers import BertTokenizer, Model2Model # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 编码输入(问题)
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question) # 编码输入(答案)
answer = "Jim Henson was a puppeteer"
encoded_answer = tokenizer.encode(answer) # 将输入转换为PyTorch张量
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])

让我们看看如何使用Model2Model获取与此(问题,答案)对相关联的loss值:

#为了计算损失,我们需要向解码器提供语言模型标签(模型生成的标记id)。
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
labels_tensor = labels_tensor.to('cuda')
model.to('cuda') # 预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(question_tensor, answer_tensor, decoder_lm_labels=labels_tensor)
# Transformers模型总是输出元组。
# 有关所有输出的详细信息,请参见models文档字符串
# 在我们的例子中,第一个元素是LM损失的值
lm_loss = outputs[0]

此损失可用于对Model2Model的问答任务进行微调。假设我们对模型进行了微调,现在让我们看看如何生成答案:

# 让我们重复前面的问题
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question)
question_tensor = torch.tensor([encoded_question]) # 这次我们试图生成答案,所以我们从一个空序列开始
answer = "[CLS]"
encoded_answer = tokenizer.encode(answer, add_special_tokens=False)
answer_tensor = torch.tensor([encoded_answer]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('fine-tuned-weights')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(question_tensor, answer_tensor)
predictions = outputs[0] # 确认我们能预测“jim”
predicted_index = torch.argmax(predictions[0, -1]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'jim'

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

Transformers 快速入门 | 一的更多相关文章

  1. Node.js快速入门

    Node.js是什么? Node.js是建立在谷歌Chrome的JavaScript引擎(V8引擎)的Web应用程序框架. 它的最新版本是:v0.12.7(在编写本教程时的版本).Node.js在官方 ...

  2. Web Api 入门实战 (快速入门+工具使用+不依赖IIS)

    平台之大势何人能挡? 带着你的Net飞奔吧!:http://www.cnblogs.com/dunitian/p/4822808.html 屁话我也就不多说了,什么简介的也省了,直接简单概括+demo ...

  3. SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=》提升)

     SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=>提升,5个Demo贯彻全篇,感兴趣的玩才是真的学) 官方demo:http://www.asp.net/si ...

  4. 前端开发小白必学技能—非关系数据库又像关系数据库的MongoDB快速入门命令(2)

    今天给大家道个歉,没有及时更新MongoDB快速入门的下篇,最近有点小忙,在此向博友们致歉.下面我将简单地说一下mongdb的一些基本命令以及我们日常开发过程中的一些问题.mongodb可以为我们提供 ...

  5. 【第三篇】ASP.NET MVC快速入门之安全策略(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  6. 【番外篇】ASP.NET MVC快速入门之免费jQuery控件库(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  7. Mybatis框架 的快速入门

    MyBatis 简介 什么是 MyBatis? MyBatis 是支持普通 SQL 查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除 了几乎所有的 JDBC 代码和参数的手工设置以及结果 ...

  8. grunt快速入门

    快速入门 Grunt和 Grunt 插件是通过 npm 安装并管理的,npm是 Node.js 的包管理器. Grunt 0.4.x 必须配合Node.js >= 0.8.0版本使用.:奇数版本 ...

  9. 【第一篇】ASP.NET MVC快速入门之数据库操作(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

随机推荐

  1. USB小白学习之路(10) CY7C68013A Slave FIFO模式下的标志位(转)

    转自良子:http://www.eefocus.com/liangziusb/blog/12-11/288618_bdaf9.html CY7C68013含有4个大端点,可以用来处理数据量较大的传输, ...

  2. NOI Online 赛前刷题计划

    Day 1 模拟 链接:Day 1  模拟 题单:P1042 乒乓球  字符串 P1015 回文数  高精 + 进制 P1088 火星人  搜索 + 数论 P1604 B进制星球  高精 + 进制 D ...

  3. 前端每日实战:125# 视频演示如何用纯 CSS 创作一个失落的人独自行走的动画

    效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/MqpOdR/ 可交互视频 此视频是 ...

  4. CSS实现响应式布局

    用CSS实现响应式布局 响应式布局感觉很高大上,很难,但实际上只用CSS也能实现响应式布局要用的就是CSS中的没接查询,下面就介绍一下怎么运用: 使用@media 的三种方法 1.直接在CSS文件中使 ...

  5. JZOJ 5230. 【NOIP2017模拟A组模拟8.5】队伍统计

    5230. [NOIP2017模拟A组模拟8.5]队伍统计 (File IO): input:count.in output:count.out Time Limits: 1500 ms Memory ...

  6. Js逆向-滑动验证码图片还原

    本文列举两个例子:某象和某验的滑动验证 一.某验:aHR0cHM6Ly93d3cuZ2VldGVzdC5jb20vZGVtby9zbGlkZS1mbG9hdC5odG1s 未还原图像: 还原后的图: ...

  7. python之模块中包的介绍

    跨文件夹导入模块 1:有文件夹a,名下有ma功能,在文件夹外调用ma功能的话, 导入import a.ma 运用ma() 或者 from a import ma ma() 2;假定a有多重文件夹,想要 ...

  8. href="#"和href=“javascript:void(0)”的区别

    void是javascript中的关键字,该操作符指定要计算一个表达式但是不返回值. <a href="javascript:void(0);">点我没有反应的!< ...

  9. 第一个爬虫经历----豆瓣电影top250(经典案例)

    因为要学习数据分析,需要从网上爬取数据,所以开始学习爬虫,使用python进行爬虫,有好几种模拟发送请求的方法,最基础的是使用urllib.request模块(python自带,无需再下载),第二是r ...

  10. Java集合03——你不得不了解的Map

    Map 在面试中永远是一个绕不开的点,本文将详细讲解Map的相关内容.关注公众号「Java面典」了解更多 Java 知识点. Map Map 是一个键值对(key-value)映射接口: 映射中不能包 ...