聊聊HuggingFace Transformer
概述
项目组件
一个完整的transformer模型主要包含三部分:Config、Tokenizer、Model。
Config
用于配置模型的名称、最终输出的样式、隐藏层宽度和深度、激活函数的类别等。
示例:
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
Tokenizer
将纯文本转换为编码的过程(注意:该过程并不会生成词向量)。由于模型(Model)并不能识别(或很好的识别)文本数据,因此对于输入的文本需要做一层编码。在这个过程中,首先会将输入文本分词而后添加某些特殊标记([MASK]标记、[SEP]、[CLS]标记),比如断句等,最后就是转换为数字类型的ID(也可以理解为是字典索引)。
示例:
pt_batch = tokenizer(
["We are very happy to show you the Transformers library.",
"We hope you don't hate it."],
padding=True,
truncation=True,
max_length=5,
return_tensors="pt"
)
## 其中,当使用list作为batch进行输入时,使用到的参数注解如下:
## padding:填充,是否将所有句子pad到同一个长度。
## truncation:截断,当遇到超过max_length的句子时是否直接截断到max_length。
## return_tensors:张量返回值,"pt"表示返回pytorch类型的tensor,"tf"表示返回TensorFlow类型的tensor,"np"表示Numpy数组。
Model
AI模型(指代基于各种算法模型,比如预训练模型、深度学习算法、强化学习算法等的实现)的抽象概念。
除了初始的Bert
、GPT
等基本模型,针对下游任务,还定义了诸如BertForQuestionAnswering
等下游任务模型。
Transformer使用
pipeline的使用
transformer
库中最基本的对象是pipeline()
函数。它将模型与其必要的预处理和后处理步骤连接起来,使我们能够直接输入任何文本并获得可理解的答案:
from transformers import pipeline
classifier = pipeline("sentiment-analysis")
classifier("I've been waiting for a HuggingFace course my whole life.")
[{'label': 'POSITIVE', 'score': 0.9598047137260437}]
默认情况下,该pipeline
函数选择一个特定的预训练模型,该模型已经过英语情感分析的微调。当创建classifier
对象时,将下载并缓存模型。如果重新运行该命令,则将使用缓存的模型,并且不需要再次下载模型。
调用pipeline函数指定预训练模型,有三个主要步骤:
- 输入的文本被预处理成模型(Model)可以理解的格式的数据(就是上述中Tokenizer组件的处理过程)。
- 预处理后的数据作为输入参数传递给模型(Model)。
- 模型的预测结果(输出内容)是经过后处理的,可供理解。
目前可用的pipelines如下:
- feature-extraction(特征提取)
- fill-mask
- ner(命名实体识别)
- question-answering(自动问答)
- sentiment-analysis(情感分析)
- summarization(摘要)
- text-generation(文本生成)
- translation(翻译)
- zero-shot-classification(文本分类)
完整说明可参考:pipelines示例说明
pipeline的原理
如上所述,pipeline将三个步骤组合在一起:预处理、通过模型传递输入以及后处理:
Tokenizer的预处理
与其他神经网络一样,Transformer 模型无法直接处理原始文本,因此pipeline
的第一步是将文本输入转换为模型可以理解的数字。为此,我们使用分词器,它将负责:
- 将输入的文本分词,即拆分为单词、子单词或符号(如标点符号),这些被称为tokens(标记)。
- 将每个token映射到一个整数。
- 添加可能对模型有用的额外输入(微调)。
预训练模型完成后,所有的预处理需要完全相同的方式完成,因此我们首先需要从Model Hub
下载该信息。 为此,我们使用 AutoTokenizer
类及其 from_pretrained()
方法。 使用模型的checkpoint
,它将自动获取与模型的标记生成器关联的数据并缓存它。
由于情感分析pipeline
的checkpoint
是 distilbert-base-uncased-finetuned-sst-2-english
,因此我们运行以下命令:
from transformers import AutoTokenizer
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
如此便得到tokenizer
对象后,后续只需将文本参数输入即可,便完成了分词-编码-转换工作。
使用Transformers框架不需要担心使用哪个后端 ML 框架(PyTorch、TensorFlow、Flax)。Transformer 模型只接受tensors(张量)作为输入参数。
注:NumPy 数组可以是标量 (0D)、向量 (1D)、矩阵 (2D) 或具有更多维度。它实际上是一个张量。
tokenizer
中的return_tensors
参数定了返回的张量类型(PyTorch、TensorFlow 或普通 NumPy)
raw_inputs = [
"I've been waiting for a HuggingFace course my whole life.",
"I hate this so much!",
]
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
print(inputs)
以下是tokenizer
返回的PyTorch张量的结果:
{
'input_ids': tensor([
[101, 1045, 1005, 2310, 2042, 3403, 2005, 1037, 17662, 12172, 2607, 2026, 2878, 2166, 1012, 102],
[101, 1045, 5223, 2023, 2061, 2172, 999, 102, 0, 0, 0, 0, 0, 0, 0, 0]
]),
'attention_mask': tensor([
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
])
}
tokenizer的返回值参数说明如下:
- 输出input_ids:经过编码后的数字(即前面所说的张量数据)。
- 输出token_type_ids:因为编码的是两个句子,这个list用于表明编码结果中哪些位置是第1个句子,哪些位置是第2个句子。具体表现为,第2个句子的位置是1,其他位置是0。
- 输出special_tokens_mask:用于表明编码结果中哪些位置是特殊符号,具体表现为,特殊符号的位置是1,其他位置是0。
- 输出attention_mask:用于表明编码结果中哪些位置是PAD。具体表现为,PAD的位置是0,其他位置是1。
- 输出length:表明编码后句子的长度。
Model层的处理
我们可以像使用tokenizer一样下载预训练模型。 Transformers 提供了一个 AutoModel
类,它也有一个 from_pretrained()
方法:
from transformers import AutoModel
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModel.from_pretrained(checkpoint)
## inputs的参数值是前面tokenizer的输出
outputs = model(**inputs)
与初始化tokenizer一样,将相同的checkpoint作为参数,初始化一个Model;而后将tokenizer的输出数据——张量数据作为参数输入到Model中。
模型的处理架构流程图,如下:
Transformer network模块有两层:嵌入层(Embeddings)、后续层(Layers)。嵌入层将标记化输入中的每个输入 ID 转换为表示关联标记的向量。 随后的层使用注意力机制操纵这些向量来产生句子的最终表示。
Transformer的输出,作为Hidden States,也可以理解为是Feature(特征数据)。
而这些特征数据,将作为模型的另一些部分的输入,比如Head层;最终由Head层输出模型的结果。
参考:
https://huggingface.co/learn/nlp-course/chapter2/1?fw=pt
聊聊HuggingFace Transformer的更多相关文章
- 昇思MindSpore全场景AI框架 1.6版本,更高的开发效率,更好地服务开发者
摘要:本文带大家快速浏览昇思MindSpore全场景AI框架1.6版本的关键特性. 全新的昇思MindSpore全场景AI框架1.6版本已发布,此版本中昇思MindSpore全场景AI框架易用性不断改 ...
- 利用Hugging Face中的模型进行句子相似性实践
Hugging Face是什么?它作为一个GitHub史上增长最快的AI项目,创始人将它的成功归功于弥补了科学与生产之间的鸿沟.什么意思呢?因为现在很多AI研究者写了大量的论文和开源了大量的代码, ...
- 大规模 Transformer 模型 8 比特矩阵乘简介 - 基于 Hugging Face Transformers、Accelerate 以及 bitsandbytes
引言 语言模型一直在变大.截至撰写本文时,PaLM 有 5400 亿参数,OPT.GPT-3 和 BLOOM 有大约 1760 亿参数,而且我们仍在继续朝着更大的模型发展.下图总结了最近的一些语言模型 ...
- 聊聊Unity项目管理的那些事:Git-flow和Unity
0x00 前言 目前所在的团队实行敏捷开发已经有了一段时间了.敏捷开发中重要的一个话题便是如何对项目进行恰当的版本管理.项目从最初使用svn到之后的Git One Track策略再到现在的GitFlo ...
- Mono为何能跨平台?聊聊CIL(MSIL)
前言: 其实小匹夫在U3D的开发中一直对U3D的跨平台能力很好奇.到底是什么原理使得U3D可以跨平台呢?后来发现了Mono的作用,并进一步了解到了CIL的存在.所以,作为一个对Unity3D跨平台能力 ...
- fir.im Weekly - 聊聊 Google 开发者大会
中国互联网的三大错觉:索尼倒闭,诺基亚崛起,谷歌重返中国.12月8日,2016 Google 开发者大会正式发布了Google Developers 中国网站 ,包含了Android Develope ...
- 聊聊asp.net中Web Api的使用
扯淡 随着app应用的崛起,后端服务开发的也越来越多,除了很多优秀的nodejs框架之外,微软当然也会在这个方面提供更便捷的开发方式.这是微软一贯的作风,如果从开发的便捷性来说的话微软是当之无愧的老大 ...
- 没有神话,聊聊decimal的“障眼法”
0x00 前言 在上一篇文章<妥协与取舍,解构C#中的小数运算>的留言区域有很多朋友都不约而同的说道了C#中的decimal类型.事实上之前的那篇文章的立意主要在于聊聊使用二进制的计算机是 ...
- 聊聊 C 语言中的 sizeof 运算
聊聊 sizeof 运算 在这两次的课上,同学们已经学到了数组了.下面几节课,应该就会学习到指针.这个速度的确是很快的. 对于同学们来说,暂时应该也有些概念理解起来可能会比较的吃力. 先说一个概念叫内 ...
- 聊聊 Apache 开源协议
摘要 用一句话概括 Apache License 就是,你可以用这代码,但是如果开源你必须保留我写的声明:你可以改我的代码,但是如果开源你必须写清楚你改了哪些:你可以加新的协议要求,但不能与我所 公布 ...
随机推荐
- drf-spectacular
介绍 drf-spectacular是为Django REST Framework生成合理灵活的OpenAPI 3.0模式.它可以自动帮我们提取接口中的信息,从而形成接口文档,而且内容十分详细,再也不 ...
- Gitlab Registries
在项目开发和部署过程中,我们常常需要一套私有仓库,比如 Code Repository.Package Repository,Docker Registry 等. Code Repository:在 ...
- HashMap 以及多线程基本感念
接口 Map :映射项,(键值对 ) 的容器注意: 键 是唯一的 值 是可以重复的实现类 HashMap :哈希表结构 允许使用null值 和 null 键 线程不安全 键唯一 无序 linkedHa ...
- RT_object
以下图片来自"张世争"的微博
- MySQL读取的记录和我想象的不一致
摘要:并发的事务在运行过程中会出现一些可能引发一致性问题的现象,本篇将详细分析一下. 本文分享自华为云社区<MySQL读取的记录和我想象的不一致--事物隔离级别和MVCC>,作者:砖业洋_ ...
- RLHF技术在智能金融中的应用:提高金融智能化和自动化水平”
目录 引言 随着人工智能技术的不断发展和普及,金融智能化和自动化水平也得到了显著提高.在这个时代,RLHF(Reinforcement Learning with Human Feedback)技术已 ...
- 详解在Linux中修改Tomcat使用的jdk版本
问题分析 由于部署个人项目使用了openjdk11,但是我之前安装的是jdk1.8,jdk版本升级的后果就是,tomcat运行的时候报一点小bug(因为之前安装tomcat默认使用了系统的jdk版本) ...
- hexo博客主题,git上传,报错Template render error的解决方案
报错信息 INFO Start processing FATAL Something's wrong. Maybe you can find the solution here: http://hex ...
- 说说 Linux 的 curl 命令
cURL,熟悉 Linux 的同学,没有人不知道这个命令吧:) 它有非常非常多的参数,我这里就不复制粘贴了,有需要可以 -h 或者谷歌搜索看看. 我从实用性的角度,说下我比较常用的几个参数: -v:啰 ...
- Win32API中的宽字符
4.1了解什么是Win32API Win32API就是windows操作系统提供给我们的函数(应用程序接口),其主要存放在C:\Windows\System32 (存储的DLL是64位).C:\Win ...