使用 LoRA 和 Hugging Face 高效训练大语言模型
在本文中,我们将展示如何使用 大语言模型低秩适配 (Low-Rank Adaptation of Large Language Models,LoRA) 技术在单 GPU 上微调 110 亿参数的 FLAN-T5 XXL 模型。在此过程中,我们会使用到 Hugging Face 的 Transformers、Accelerate 和 PEFT 库。
通过本文,你会学到:
- 如何搭建开发环境
- 如何加载并准备数据集
- 如何使用 LoRA 和 bnb (即 bitsandbytes) int-8 微调 T5
- 如何评估 LoRA FLAN-T5 并将其用于推理
- 如何比较不同方案的性价比
另外,你可以 点击这里 在线查看此博文对应的 Jupyter Notebook。
快速入门: 轻量化微调 (Parameter Efficient Fine-Tuning,PEFT)
PEFT 是 Hugging Face 的一个新的开源库。使用 PEFT 库,无需微调模型的全部参数,即可高效地将预训练语言模型 (Pre-trained Language Model,PLM) 适配到各种下游应用。PEFT 目前支持以下几种方法:
- LoRA: LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS
- Prefix Tuning: P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks
- P-Tuning: GPT Understands, Too
- Prompt Tuning: The Power of Scale for Parameter-Efficient Prompt Tuning
注意: 本教程是在 g5.2xlarge AWS EC2 实例上创建和运行的,该实例包含 1 个 NVIDIA A10G。
1. 搭建开发环境
在本例中,我们使用 AWS 预置的 PyTorch 深度学习 AMI,其已安装了正确的 CUDA 驱动程序和 PyTorch。在此基础上,我们还需要安装一些 Hugging Face 库,包括 transformers 和 datasets。运行下面的代码就可安装所有需要的包。
# install Hugging Face Libraries
!pip install git+https://github.com/huggingface/peft.git
!pip install "transformers==4.27.1" "datasets==2.9.0" "accelerate==0.17.1" "evaluate==0.4.0" "bitsandbytes==0.37.1" loralib --upgrade --quiet
# install additional dependencies needed for training
!pip install rouge-score tensorboard py7zr
2. 加载并准备数据集
这里,我们使用 samsum 数据集,该数据集包含大约 16k 个含摘要的聊天类对话数据。这些对话由精通英语的语言学家制作。
{
"id": "13818513",
"summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
"dialogue": "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"
}
我们使用 Datasets 库中的 load_dataset()
方法来加载 samsum
数据集。
from datasets import load_dataset
# Load dataset from the hub
dataset = load_dataset("samsum")
print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")
# Train dataset size: 14732
# Test dataset size: 819
为了训练模型,我们要用 Transformers Tokenizer 将输入文本转换为词元 ID。如果你需要了解这一方面的知识,请移步 Hugging Face 课程的 第 6 章。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_id="google/flan-t5-xxl"
# Load tokenizer of FLAN-t5-XL
tokenizer = AutoTokenizer.from_pretrained(model_id)
在开始训练之前,我们还需要对数据进行预处理。生成式文本摘要属于文本生成任务。我们将文本输入给模型,模型会输出摘要。我们需要了解输入和输出文本的长度信息,以利于我们高效地批量处理这些数据。
from datasets import concatenate_datasets
import numpy as np
# The maximum total input sequence length after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
input_lenghts = [len(x) for x in tokenized_inputs["input_ids"]]
# take 85 percentile of max length for better utilization
max_source_length = int(np.percentile(input_lenghts, 85))
print(f"Max source length: {max_source_length}")
# The maximum total sequence length for target text after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded."
tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["summary"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
target_lenghts = [len(x) for x in tokenized_targets["input_ids"]]
# take 90 percentile of max length for better utilization
max_target_length = int(np.percentile(target_lenghts, 90))
print(f"Max target length: {max_target_length}")
我们将在训练前统一对数据集进行预处理并将预处理后的数据集保存到磁盘。你可以在本地机器或 CPU 上运行此步骤并将其上传到 Hugging Face Hub。
def preprocess_function(sample,padding="max_length"):
# add prefix to the input for t5
inputs = ["summarize: " + item for item in sample["dialogue"]]
# tokenize inputs
model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
# Tokenize targets with the `text_target` keyword argument
labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length":
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")
# save datasets to disk for later easy loading
tokenized_dataset["train"].save_to_disk("data/train")
tokenized_dataset["test"].save_to_disk("data/eval")
3. 使用 LoRA 和 bnb int-8 微调 T5
除了 LoRA 技术,我们还使用 bitsanbytes LLM.int8() 把冻结的 LLM 量化为 int8。这使我们能够将 FLAN-T5 XXL 所需的内存降低到约四分之一。
训练的第一步是加载模型。我们使用 philschmid/flan-t5-xxl-sharded-fp16 模型,它是 google/flan-t5-xxl 的分片版。分片可以让我们在加载模型时不耗尽内存。
from transformers import AutoModelForSeq2SeqLM
# huggingface hub model id
model_id = "philschmid/flan-t5-xxl-sharded-fp16"
# load model from the hub
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
现在,我们可以使用 peft
为 LoRA int-8 训练作准备了。
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# Define LoRA Config
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q", "v"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.SEQ_2_SEQ_LM
)
# prepare int-8 model for training
model = prepare_model_for_int8_training(model)
# add LoRA adaptor
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 18874368 || all params: 11154206720 || trainable%: 0.16921300163961817
如你所见,这里我们只训练了模型参数的 0.16%!这个巨大的内存增益让我们安心地微调模型,而不用担心内存问题。
接下来需要创建一个 DataCollator
,负责对输入和标签进行填充,我们使用 Transformers 库中的 DataCollatorForSeq2Seq
来完成这一环节。
from transformers import DataCollatorForSeq2Seq
# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8
)
最后一步是定义训练超参 ( TrainingArguments
)。
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
output_dir="lora-flan-t5-xxl"
# Define training args
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
auto_find_batch_size=True,
learning_rate=1e-3, # higher learning rate
num_train_epochs=5,
logging_dir=f"{output_dir}/logs",
logging_strategy="steps",
logging_steps=500,
save_strategy="no",
report_to="tensorboard",
)
# Create Trainer instance
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_dataset["train"],
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
运行下面的代码,开始训练模型。请注意,对于 T5,出于收敛稳定性考量,某些层我们仍保持 float32
精度。
# train model
trainer.train()
训练耗时约 10 小时 36 分钟,训练 10 小时的成本约为 13.22 美元
。相比之下,如果 在 FLAN-T5-XXL 上进行全模型微调 10 个小时,我们需要 8 个 A100 40GB,成本约为 322 美元。
我们可以将模型保存下来以用于后面的推理和评估。我们暂时将其保存到磁盘,但你也可以使用 model.push_to_hub
方法将其上传到 Hugging Face Hub。
# Save our LoRA model & tokenizer results
peft_model_id="results"
trainer.model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)
# if you want to save the base model to call
# trainer.model.base_model.save_pretrained(peft_model_id)
最后生成的 LoRA checkpoint 文件很小,仅需 84MB 就包含了从 samsum
数据集上学到的所有知识。
4. 使用 LoRA FLAN-T5 进行评估和推理
我们将使用 evaluate
库来评估 rogue
分数。我们可以使用 PEFT
和 transformers
来对 FLAN-T5 XXL 模型进行推理。对 FLAN-T5 XXL 模型,我们至少需要 18GB 的 GPU 显存。
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load peft config for pre-trained checkpoint etc.
peft_model_id = "results"
config = PeftConfig.from_pretrained(peft_model_id)
# load base LLM model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Load the Lora model
model = PeftModel.from_pretrained(model, peft_model_id, device_map={"":0})
model.eval()
print("Peft model loaded")
我们用测试数据集中的一个随机样本来试试摘要效果。
from datasets import load_dataset
from random import randrange
# Load dataset from the hub and get a sample
dataset = load_dataset("samsum")
sample = dataset['test'][randrange(len(dataset["test"]))]
input_ids = tokenizer(sample["dialogue"], return_tensors="pt", truncation=True).input_ids.cuda()
# with torch.inference_mode():
outputs = model.generate(input_ids=input_ids, max_new_tokens=10, do_sample=True, top_p=0.9)
print(f"input sentence: {sample['dialogue']}\n{'---'* 20}")
print(f"summary:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]}")
不错!我们的模型有效!现在,让我们仔细看看,并使用 test
集中的全部数据对其进行评估。为此,我们需要实现一些工具函数来帮助生成摘要并将其与相应的参考摘要组合到一起。评估摘要任务最常用的指标是 rogue_score,它的全称是 Recall-Oriented Understudy for Gisting Evaluation。与常用的准确率指标不同,它将生成的摘要与一组参考摘要进行比较。
import evaluate
import numpy as np
from datasets import load_from_disk
from tqdm import tqdm
# Metric
metric = evaluate.load("rouge")
def evaluate_peft_model(sample,max_target_length=50):
# generate summary
outputs = model.generate(input_ids=sample["input_ids"].unsqueeze(0).cuda(), do_sample=True, top_p=0.9, max_new_tokens=max_target_length)
prediction = tokenizer.decode(outputs[0].detach().cpu().numpy(), skip_special_tokens=True)
# decode eval sample
# Replace -100 in the labels as we can't decode them.
labels = np.where(sample['labels']!= -100, sample['labels'], tokenizer.pad_token_id)
labels = tokenizer.decode(labels, skip_special_tokens=True)
# Some simple post-processing
return prediction, labels
# load test dataset from distk
test_dataset = load_from_disk("data/eval/").with_format("torch")
# run predictions
# this can take ~45 minutes
predictions, references = [], []
for sample in tqdm(test_dataset):
p,l = evaluate_peft_model(sample)
predictions.append(p)
references.append(l)
# compute metric
rogue = metric.compute(predictions=predictions, references=references, use_stemmer=True)
# print results
print(f"Rogue1: {rogue['rouge1']* 100:2f}%")
print(f"rouge2: {rogue['rouge2']* 100:2f}%")
print(f"rougeL: {rogue['rougeL']* 100:2f}%")
print(f"rougeLsum: {rogue['rougeLsum']* 100:2f}%")
# Rogue1: 50.386161%
# rouge2: 24.842412%
# rougeL: 41.370130%
# rougeLsum: 41.394230%
我们 PEFT 微调后的 FLAN-T5-XXL 在测试集上取得了 50.38%
的 rogue1 分数。相比之下,flan-t5-base 的全模型微调获得了 47.23 的 rouge1 分数。rouge1 分数提高了 3%
。
令人难以置信的是,我们的 LoRA checkpoint 只有 84MB,而且性能比对更小的模型进行全模型微调后的 checkpoint 更好。
你可以 点击这里 在线查看此博文对应的 Jupyter Notebook。
英文原文: https://www.philschmid.de/fine-tune-flan-t5-peft
原文作者:Philipp Schmid
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。
排版/审校: zhongdongy (阿东)
使用 LoRA 和 Hugging Face 高效训练大语言模型的更多相关文章
- Java实现蓝桥杯 算法训练 大等于n的最小完全平方数
试题 算法训练 大等于n的最小完全平方数 资源限制 时间限制:1.0s 内存限制:256.0MB 问题描述 输出大等于n的最小的完全平方数. 若一个数能表示成某个自然数的平方的形式,则称这个数为完全平 ...
- Java高效读取大文件
1.概述 本教程将演示如何用Java高效地读取大文件.这篇文章是Baeldung (http://www.baeldung.com/) 上“Java——回归基础”系列教程的一部分. 2.在内存中读取 ...
- Java高效读取大文件(转)
1.概述 本教程将演示如何用Java高效地读取大文件.这篇文章是Baeldung(http://www.baeldung.com/) 上“Java——回归基础”系列教程的一部分. 2.在内存中读取 读 ...
- 图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用
1.ERNIESage运行实例介绍(1.8x版本) 本项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?contribut ...
- Python:高效计算大文件中的最长行的长度
在操作某个很多进程都要频繁用到的大文件的时候,应该尽早释放文件资源(f.close()) 前2种方法主要用到了列表解析,性能稍差,而最后一种使用的时候生成器表达式,相比列表解析,更省内存 列表解析和生 ...
- 高效读取大文件,再也不用担心 OOM 了!
内存读取 第一个版本,采用内存读取的方式,所有的数据首先读读取到内存中,程序代码如下: Stopwatch stopwatch = Stopwatch.createStarted(); // 将全部行 ...
- 号外!号外!这个敏捷高效的大数据bi看板可以免费使用啦!
随着信息革命的深入推进,数据已经成为国家基础性战略资源,各个行业开始重视数据分析,企业不同,数据分析需求当然不一样,如销售行业需要对商品进行销售分析:网站运营需要进行用户.渠道.流量等信息分析:制造行 ...
- select2,利用ajax高效查询大数据列表(可搜索、可分页)
二.导入css和js到网站上 1.使用CDN,节省自己网站的流量 ? 1 2 <link href="https://cdnjs.cloudflare.com/ajax/libs/se ...
- word2vec——高效word特征提取
继上次分享了经典统计语言模型,最近公众号中有很多做NLP朋友问到了关于word2vec的相关内容, 本文就在这里整理一下做以分享. 本文分为 概括word2vec 相关工作 模型结构 Count-ba ...
- 顶尖大数据挖掘实战平台(TipDM-H8)产品白皮书
顶尖大数据挖掘实战平台 (TipDM-H8) 产 品 说 明 书 广州泰迪智能科技有限公司 版权所有 地址: 广州市经济技术开发区科学城232号 网址: http: ...
随机推荐
- javaweb本地启动很快,服务器上面启动特别慢
在JVM环境中解决 打开$JAVA_PATH/jre/lib/security/java.security这个文件,找到下面的内容: securerandom.source=file:/dev/ura ...
- layui.dtree的学习,自定义扩展toolbar按钮(toolbarExt)
学习layui.dtree请前往 http://www.wisdomelon.com/DTreeHelper/ 记录一下dtree的自定义扩展toolbar按钮(toolbarExt) html代码: ...
- MacOS 使用UnblockNeteaseMusic解锁网易云灰色歌曲(主要是想听杰伦)
最近想听杰伦的音乐 但是网易云木有版权 于是在github上找到了UnblockNeteaseMusic这个项目 不多废话 直接上教程! 第一步 找到该项目的地址 并使用git克隆到本地: https ...
- 关于 indy Idhttp Post数据抛异常 connection closed Gracefully
delphi 使用indy -idHttp 控件post 数据时 会报connection closed Gracefully这个异常的问题. 网上找了很多方法最多的就是 修改: MyHttp.Req ...
- 杨辉三角形实现过程详解-C语言基础
这一篇要探讨的是"杨辉三角形的实现以及如何人工走循环".涉及的知识点和内容很少,主要是想说明如何看懂循环,如何跟着循环走.属于C语言基础篇. 学习编程的人,在学习的初期,几乎都会接 ...
- JVM调优学习笔记
TODO:需要学习的命令 jps jstat -gcutil pid xxxx jmap histo:live pid
- mac安装mysql5.6默认密码修改
前言: 每次安装mysql都被烦的要死,痛并思痛记下此篇文章: 参考: https://blog.csdn.net/u010377383/article/details/82688250 https: ...
- webservice学习随笔(一):简单的webservice实例
一.webService概念简单介绍: 简单来说,webservice就是远程调用技术,也叫XML Web Service WebService是一种可以接收从Internet或者Intranet上的 ...
- jenkins freestyle deploy web
gitlab connection 选择定义好的gitlab仓库 参数化构建过程 git参数 名称 branch 描述 自定义 参数类型 分支 默认值 $branch 选项参数 名称 Status 选 ...
- win8 改win7 最全教程(包含可能遇到的所有问题)
今日,帮一个朋友的把她的系统从win8 优雅降级到了win7,大家都知道win8改win7 不好改啊.......话不多,上本人的总结的教程. 首先 ,win8改win7 需要对系统格盘,这里的原因我 ...