作者|huggingface

编译|VK

来源|Github

理念

Transformers是一个为NLP的研究人员寻求使用/研究/扩展大型Transformers模型的库。

该库的设计有两个强烈的目标:

  • 尽可能简单和快速使用:

    • 我们尽可能限制了要学习的面向对象抽象的类的数量,实际上几乎没有抽象,每个模型只需要使用三个标准类:配置、模型和tokenizer,
    • 所有这些类都可以通过使用公共的from_pretrained()实例化方法从预训练实例以简单统一的方式初始化,该方法将负责从库中下载,缓存和加载相关类提供的预训练模型或你自己保存的模型。
    • 因此,这个库不是构建神经网络模块的工具箱。如果您想扩展/构建这个库,只需使用常规的Python/PyTorch模块,并从这个库的基类继承,以重用诸如模型加载/保存等功能。
  • 提供最先进的模型与性能尽可能接近的原始模型:
    • 我们为每个架构提供了至少一个例子,该例子再现了上述架构的官方作者提供的结果
    • 代码通常尽可能地接近原始代码,这意味着一些PyTorch代码可能不那么pytorch化,因为这是转换TensorFlow代码后的结果。

其他几个目标:

  • 尽可能一致地暴露模型的内部:

    • 我们使用一个API来访问所有的隐藏状态和注意力权重,
    • 对tokenizer和基本模型的API进行了标准化,以方便在模型之间进行切换。
  • 结合一个主观选择的有前途的工具微调/调查这些模型:
    • 向词汇表和嵌入项添加新标记以进行微调的简单/一致的方法,
    • 简单的方法面具和修剪变压器头。

主要概念

该库是建立在三个类型的类为每个模型:

  • model类是目前在库中提供的8个模型架构的PyTorch模型(torch.nn.Modules),例如BertModel
  • configuration类,它存储构建模型所需的所有参数,例如BertConfig。您不必总是自己实例化这些配置,特别是如果您使用的是未经任何修改的预训练的模型,创建模型将自动负责实例化配置(它是模型的一部分)
  • tokenizer类,它存储每个模型的词汇表,并在要输送到模型的词汇嵌入索引列表中提供用于编码/解码字符串的方法,例如BertTokenizer

所有这些类都可以从预训练模型来实例化,并使用两种方法在本地保存:

  • from_pretraining()允许您从一个预训练版本实例化一个模型/配置/tokenizer,这个预训练版本可以由库本身提供(目前这里列出了27个模型),也可以由用户在本地(或服务器上)存储,
  • save_pretraining()允许您在本地保存模型/配置/tokenizer,以便可以使用from_pretraining()重新加载它。

我们将通过一些简单的快速启动示例来完成这个快速启动之旅,看看如何实例化和使用这些类。其余的文件分为两部分:

  • 主要的类详细介绍了三种主要类(配置、模型、tokenizer)的公共功能/方法/属性,以及一些作为训练工具提供的优化类,
  • 包引用部分详细描述了每个模型体系结构的每个类的所有变体,特别是调用它们时它们期望的输入和输出。

快速入门:使用

这里有两个例子展示了一些Bert和GPT2类以及预训练模型。

有关每个模型类的示例,请参阅完整的API参考。

BERT示例

让我们首先使用BertTokenizer从文本字符串准备一个标记化的输入(要输入给BERT的标记嵌入索引列表)

import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练的模型标记器(词汇表)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 标记输入
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text) # 用“BertForMaskedLM”掩盖我们试图预测的标记`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]'] # 将标记转换为词汇索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 定义与第一句和第二句相关的句子A和B索引(见论文)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

让我们看看如何使用BertModel在隐藏状态下对输入进行编码:

# 加载预训练模型(权重)
model = BertModel.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') #预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
# Transformer模型总是输出元组。
# 有关所有输出的详细信息,请参见模型文档字符串。在我们的例子中,第一个元素是Bert模型最后一层的隐藏状态
encoded_layers = outputs[0]
# 我们已将输入序列编码为形状(批量大小、序列长度、模型隐藏维度)的FloatTensor
assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)

以及如何使用BertForMaskedLM预测屏蔽的标记:

# 加载预训练模型(权重)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0] # 确认我们能预测“henson”
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

OpenAI GPT-2

下面是一个快速开始的例子,使用GPT2TokenizerGPT2LMHeadModel类以及OpenAI的预训练模型来预测文本提示中的下一个标记。

首先,让我们使用GPT2Tokenizer

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # 编码输入
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text) # 转换为PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])

让我们看看如何使用GPT2LMHeadModel生成下一个跟在我们的文本后面的token:

# 加载预训练模型(权重)
model = GPT2LMHeadModel.from_pretrained('gpt2') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0] # 得到预测的下一个子词(在我们的例子中,是“man”这个词)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'

每个模型架构(Bert、GPT、GPT-2、Transformer XL、XLNet和XLM)的每个模型类的示例,可以在文档中找到。

使用过去的GPT-2

以及其他一些模型(GPT、XLNet、Transfo XL、CTRL),使用pastmems属性,这些属性可用于防止在使用顺序解码时重新计算键/值对。它在生成序列时很有用,因为注意力机制的很大一部分得益于以前的计算。

下面是一个使用带pastGPT2LMHeadModel和argmax解码的完整工作示例(只能作为示例,因为argmax decoding引入了大量重复):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2') generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[..., -1, :]) generated += [token.tolist()]
context = token.unsqueeze(0) sequence = tokenizer.decode(generated) print(sequence)

由于以前所有标记的键/值对都包含在past,因此模型只需要一个标记作为输入。

Model2Model示例

编码器-解码器架构需要两个标记化输入:一个用于编码器,另一个用于解码器。假设我们想使用Model2Model进行生成性问答,从标记将输入模型的问答开始。

import torch
from transformers import BertTokenizer, Model2Model # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 编码输入(问题)
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question) # 编码输入(答案)
answer = "Jim Henson was a puppeteer"
encoded_answer = tokenizer.encode(answer) # 将输入转换为PyTorch张量
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])

让我们看看如何使用Model2Model获取与此(问题,答案)对相关联的loss值:

#为了计算损失,我们需要向解码器提供语言模型标签(模型生成的标记id)。
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
labels_tensor = labels_tensor.to('cuda')
model.to('cuda') # 预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(question_tensor, answer_tensor, decoder_lm_labels=labels_tensor)
# Transformers模型总是输出元组。
# 有关所有输出的详细信息,请参见models文档字符串
# 在我们的例子中,第一个元素是LM损失的值
lm_loss = outputs[0]

此损失可用于对Model2Model的问答任务进行微调。假设我们对模型进行了微调,现在让我们看看如何生成答案:

# 让我们重复前面的问题
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question)
question_tensor = torch.tensor([encoded_question]) # 这次我们试图生成答案,所以我们从一个空序列开始
answer = "[CLS]"
encoded_answer = tokenizer.encode(answer, add_special_tokens=False)
answer_tensor = torch.tensor([encoded_answer]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('fine-tuned-weights')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(question_tensor, answer_tensor)
predictions = outputs[0] # 确认我们能预测“jim”
predicted_index = torch.argmax(predictions[0, -1]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'jim'

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

Transformers 快速入门 | 一的更多相关文章

  1. Node.js快速入门

    Node.js是什么? Node.js是建立在谷歌Chrome的JavaScript引擎(V8引擎)的Web应用程序框架. 它的最新版本是:v0.12.7(在编写本教程时的版本).Node.js在官方 ...

  2. Web Api 入门实战 (快速入门+工具使用+不依赖IIS)

    平台之大势何人能挡? 带着你的Net飞奔吧!:http://www.cnblogs.com/dunitian/p/4822808.html 屁话我也就不多说了,什么简介的也省了,直接简单概括+demo ...

  3. SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=》提升)

     SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=>提升,5个Demo贯彻全篇,感兴趣的玩才是真的学) 官方demo:http://www.asp.net/si ...

  4. 前端开发小白必学技能—非关系数据库又像关系数据库的MongoDB快速入门命令(2)

    今天给大家道个歉,没有及时更新MongoDB快速入门的下篇,最近有点小忙,在此向博友们致歉.下面我将简单地说一下mongdb的一些基本命令以及我们日常开发过程中的一些问题.mongodb可以为我们提供 ...

  5. 【第三篇】ASP.NET MVC快速入门之安全策略(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  6. 【番外篇】ASP.NET MVC快速入门之免费jQuery控件库(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  7. Mybatis框架 的快速入门

    MyBatis 简介 什么是 MyBatis? MyBatis 是支持普通 SQL 查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除 了几乎所有的 JDBC 代码和参数的手工设置以及结果 ...

  8. grunt快速入门

    快速入门 Grunt和 Grunt 插件是通过 npm 安装并管理的,npm是 Node.js 的包管理器. Grunt 0.4.x 必须配合Node.js >= 0.8.0版本使用.:奇数版本 ...

  9. 【第一篇】ASP.NET MVC快速入门之数据库操作(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

随机推荐

  1. Oracle中的列转行实现字段拼接用例

    文章目录 Oracle中的列转行实现字段拼接 场景 在SQL使用过程中经常有这种需求:将某列字段拼接成in('XX','XX','XX','XX','XX','XX' ...)做为查询条件. 实现 s ...

  2. 沈向洋|微软携手 OpenAI 进一步履行普及且全民化人工智能的使命

    OpenAI 进一步履行普及且全民化人工智能的使命"> 作者简介 沈向洋,微软全球执行副总裁,微软人工智能及微软研究事业部负责人 我们正处于技术发展历程中的关键时刻. 云计算的强大计算 ...

  3. C#使用正则表达式获取HTML代码中a标签里包含指定后缀的href的值

    //C#使用正则表达式获取HTML代码中a标签里包含指定后缀的href的值,表达式如下: Regex regImg = new Regex(@"(?is)<a[^>]*?href ...

  4. 一起了解 .Net Foundation 项目 No.9

    .Net 基金会中包含有很多优秀的项目,今天就和笔者一起了解一下其中的一些优秀作品吧. 中文介绍 中文介绍内容翻译自英文介绍,主要采用意译.如与原文存在出入,请以原文为准. DLR/IronPytho ...

  5. 聊一聊MyBatis 和 SQL 注入间的恩恩怨怨

    整理了一些Java方面的架构.面试资料(微服务.集群.分布式.中间件等),有需要的小伙伴可以关注公众号[程序员内点事],无套路自行领取 更多优选 一口气说出 9种 分布式ID生成方式,面试官有点懵了 ...

  6. PhaserJS 3 屏幕适配时的小坑 -- JavaScript Html5 游戏开发

    巨坑:在config内不要把 width 设为 window.innnerWidth在config内不要把 width 设为 window.innnerWidth在config内不要把 width 设 ...

  7. python如何在图片上添加文字(中文和英文)

    Python在图片上添加文字的两种方法:OpenCV和PIL 一.OpenCV方法 1.安装cv2 pip install opencv-python 2.利用putText方法来实现在图片的指定位置 ...

  8. centos7搭建ceph集群

    一.服务器规划 主机名 主机IP 磁盘配比 角色 node1 public-ip:10.0.0.130cluster-ip:192.168.2.130 sda,sdb,sdcsda是系统盘,另外两块数 ...

  9. Python爬虫 - UserAgent列表

    PC端: PC_USER_AGENT = [ 'Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1)', 'Mozilla/4.0 (compatibl ...

  10. W3C的盒子模型和IE的盒子模型

    盒子模型分为两种:W3C盒子模型(标准盒子模型)和IE盒子模型 盒子模型组成:content+padding+border+margin 标准盒子模型的width就是content 而IE盒子模型的w ...