Transformers Pipelines
pipelines 是使用模型进行推理的一种很好且简单的方法。这些pipelines 是从库中抽象出大部分复杂代码的对象,提供了一个简单的API,专门用于多个任务,包括命名实体识别、屏蔽语言建模、情感分析、特征提取和问答等。
参数说明
初始化pipeline时可能的参数:
task (str
) — 定义pipeline需要返回的任务。
model (str
or PreTrainedModel or TFPreTrainedModel, optional) — 拟使用的模型,有时可以只指定模型,不指定task
config (str
or PretrainedConfig, optional) — 实例化模型的配置。取值可以是一个模型标志符(模型名称),也可以是利用PretrainedConfig
继承得来
tokenizer (str
or PreTrainedTokenizer, optional) — 用于编码模型中的数据。取值可以是一个模型标志符(模型名称),也可以是利用 PreTrainedTokenizer
继承得来
feature_extractor (str
or PreTrainedFeatureExtractor
, optional) — 特征提取器
framework (str
, optional) — 指明运行模型的框架,要么是"pt"(表示pytorch), 要么是"tf"(表示tensorflow)
revision (str
, optional, defaults to "main"
) — 指定所加载模型的版本
use_fast (bool
, optional, defaults to True
) — 如果可以的话(a PreTrainedTokenizerFast),是否使用Fast tokenizer
use_auth_token (str
or bool, optional) — 是否需要认证
device (int
or str
or torch.device
) — 指定运行模型的硬件设备。(例如:"cpu","cuda:1","mps",或者是一个GPU的编号,比如 1)
device_map (str
or Dict[str, Union[int, str, torch.device]
, optional) — Sent directly as model_kwargs
(just a simpler shortcut). When accelerate
library is present, set device_map="auto"
to compute the most optimized device_map
automatically. More information
torch_dtype (str
or torch.dtype
, optional) — 指定模型可用的精度。sent directly as model_kwargs
(just a simpler shortcut) to use the available precision for this model (torch.float16
, torch.bfloat16
, … or "auto"
).
trust_remote_code (bool
, optional, defaults to False
) —
使用pipeline对象处理数据可能的参数:
batch_size (int
) — 数据处理的批次大小
truncation (bool
, optional, defaults to False
) — 是否截断
padding (bool
, optional, defaults to False
) — 是否padding
实例:
初始化一个文本分类其的pipeline 对象
from transformers import pipeline
classifier = pipeline(task="text-classification") #
模型的输入inputs(可以是一个字典、列表、单个字符串)。
inputs = "嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻"
使用pipeline对象处理数据
results = classifier(inputs, truncation=True, padding=True, max_length=512):
批处理的使用建议
- 在有延迟限制的实时任务中, 别用批处理
- 使用CPU进行预测时,别用批处理
- 如果您不知道sequence_length的大小(例如自然数据),别用批处理。设置OOM检查,以便于过长输入序列导致模型执行异常时,模型可以自动重启
- 如果输入中包含100个样本序列,仅一个样本序列长度是600,其余长度为4,那么当它们作为一个批次输入时,输入数据的shape也仍是(100, 600)。
- 如果样本sequence_length比较规整,则建议使用尽可能大的批次。
- 总之,使用批处理需要处理更好溢出问题
Pipelines 主要包括三大模块(以TextClassificationPipeline
为例)
数据预处理:tokenize 未处理的数据。对应pipeline中的preprocess()
前向计算:模型进行计算。对应pipeline中的_forward()
后处理:对模型计算出来的logits
进行进一步处理。对应pipeline中的postprocess()
PS: 继承pipeline类(如TextClassificationPipeline
),并重写以上的三个函数,可以实现自定义pipeline
实例(修改模型预测的标签):此处以修改模型预测标签为例,重写后处理过程postprocess()
1、导库,并读取提前准备好的标签映射数据
# coding=utf-8
from transformers import pipeline
from transformers import TextClassificationPipeline
import numpy as np
import json
import pandas as pd
with open('model_save_epochs100_batch1/labelmap.json')as fr:
id2label = {ind: label for ind, label in enumerate(json.load(fr).values())}
2、自定义pipeline(继承TextClassificationPipeline
,并重写postprocess()
)
class CustomTextClassificationPipeline(TextClassificationPipeline):
def sigmoid_(self, _outputs):
return 1.0 / (1.0 + np.exp(-_outputs))
def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
outputs = model_outputs["logits"][0] # 感觉这里每次只会返回一个样本的计算结果
outputs = outputs.numpy()
scores = self.sigmoid_(outputs)
dict_scores = [
{"label": id2label[i], "score": score.item()} for i, score in enumerate(scores) if score > 0.5
]
return dict_scores
3、使用自定义的pipeline实例化一个分类器(有两种方式)
方式一:将自定义类名传参给pipeline_class
classifier = pipeline(model='model_save_epochs100_batch1/checkpoint-325809',
pipeline_class=CustomTextClassificationPipeline,
task="text-classification",
function_to_apply='sigmoid', top_k=10, device=0) # return_all_scores=True
方式二:直接使用自定义类创建分类器
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained('model_save_epochs100_batch1/checkpoint-325809')
classifier = CustomTextClassificationPipeline(model=model, # 此处model的值得是加载好了得模型,不能是一个字符串
pipeline_class=CustomTextClassificationPipeline,
task="text-classification",
function_to_apply='sigmoid', top_k=10, device=0) # return_all_scores=True
4、利用分类器对文本数据进行预测
# 读取待测数据
with open('raw_data/diseasecontent.json', 'r', encoding='utf-8') as fr:
texts = [text.strip() for text in json.load(fr)if text.strip("000").strip()]
res = []
for text in texts:
labels = []
for ite in classifier(text, truncation=True, max_length=512): # padding=True, 执行预测
labels.append(ite['label'])
res.append({'text': text, 'labels': labels})
# 保存预测结果
df = pd.DataFrame(res)
df.to_excel('model_save_epochs100_batch1/test_res.xlsx')
Transformers Pipelines的更多相关文章
- Spark2.0 Pipelines
MLlib中众多机器学习算法API在单一管道或工作流中更容易相互结合起来使用.管道的思想主要是受到scikit-learn库的启发. ML API使用Spark SQL中的DataFrame作为机器学 ...
- kaggle Pipelines
# Most scikit-learn objects are either transformers or models. # Transformers are for pre-processing ...
- ML Pipelines管道
ML Pipelines管道 In this section, we introduce the concept of ML Pipelines. ML Pipelines provide a uni ...
- Nancy之Pipelines三兄弟(Before After OnError)
一.简单描述 Before:如果返回null,拦截器将主动权转给路由:如果返回Response对象,则路由不起作用. After : 没有返回值,可以在这里修改或替换当前的Response. OnEr ...
- Coax Transformers[转载]
Coax Transformers How to determine the needed Z for a wanted Quarter Wave Lines tranformation ratio ...
- 【最短路】ACdream 1198 - Transformers' Mission
Problem Description A group of transformers whose leader is Optimus Prime(擎天柱) were assigned a missi ...
- Linux - 命令行 管道(Pipelines) 详细解释
命令行 管道(Pipelines) 详细解释 本文地址: http://blog.csdn.net/caroline_wendy/article/details/24249529 管道操作符" ...
- 使用 Bitbucket Pipelines 持续交付托管项目
简介 Bitbucket Pipelines 是Atlassian公司为Bitbucket Cloud产品添加的一个新功能, 它为托管在Bitbucket上的项目提供了一个良好的持续集成/交付的服务. ...
- Easy machine learning pipelines with pipelearner: intro and call for contributors
@drsimonj here to introduce pipelearner – a package I'm developing to make it easy to create machine ...
- sql hibernate查询转换成实体或对应的VO Transformers
sql查询转换成实体或对应的VO Transformers //addScalar("id") 默认查询出来的id是全部大写的(sql起别名也无效,所以使用.addScalar(& ...
随机推荐
- tornado原理介绍及异步非阻塞实现方式
tornado原理介绍及异步非阻塞实现方式 以下内容根据自己实操和理解进行的整理,欢迎交流~ 在tornado的开发中,我们一般会见到以下四个组成部分. ioloop: 同一个ioloop实例运行在一 ...
- python之路22 hashlib、subprocess、logging模块
hashlib加密模块 hashlib模块为不同的安全哈希/安全散列(Secure Hash Algorithm)和 信息摘要算法(Message Digest Algorithm)实现了一个公共的. ...
- 10分钟做好 Bootstrap Blazor 的表格组件导出 Excel/Word/Html/Pdf
上篇: Bootstrap Blazor 实战 通用导入导出服务(Table组件) 1.新建工程 新建工程b14table dotnet new blazorserver -o b14table 将项 ...
- [Unity]限制两个物体之间的距离
//限制两个物体之间的距离 if (Vector3.Distance(B.position, A.position) > maxDistance) { //获得两个物体之间的单位向量 Vecto ...
- SPOJLCMSUM - LCM Sum
简要题意 \(T\) 组数据,每组数据给出一个 \(n\),计算: \[\sum_{i=1}^{n}{\operatorname{lcm}(i,n)} \] \(1 \leq T \leq 3\tim ...
- C# 线程查漏补缺
进程和线程 不同程序执行需要进行调度和独立的内存空间 在单核计算机中,CPU 是独占的,内存是共享的,这时候运行一个程序的时候是没有问题.但是运行多个程序的时候,为了不发生一个程序霸占整个 CPU 不 ...
- 最新编程语言排名Python、C、Java 和 C++ 已形成四足鼎立之势
引言 技术的千变万化,都是有迹可循的,随着最新的 TIOBE 十月编程语言榜单重磅发布,不同开发语言的排名和发展趋势也随之揭晓! 四大编程语言不断增强其主导地位 曾几何时,编程语言界中 Java.C. ...
- ua5.4源码剖析:三. C++与Lua相互调用
概述 从本质上来看,其实说是不存在所谓的C++与lua的相互调用.lua是运行在C上的,简单来说lua的代码会被编译成字节码在被C语言的语法运行.在C++调用lua时,其实是解释运行lua文件编译出来 ...
- SSM中PageHelper的使用方法
SSM中PageHelper的使用方法 转载于for dream 第一步.导包(或者导入坐标) <!-- https://mvnrepository.com/artifact/com.githu ...
- KEIL5、STM32CubeMX、STM32CubeIDE 下载、安装
一.资源下载 Keil5下载链接: https://www.keil.com/download/product/ STM32 标准库芯片包下载链接: https://www.keil.com/dd2/ ...