这篇博客是一篇来自 Meta AI,关于指令微调 Llama 2 的扩展说明。旨在聚焦构建指令数据集,有了它,我们则可以使用自己的指令来微调 Llama 2 基础模型。

目标是构建一个能够基于输入内容来生成指令的模型。这么做背后的逻辑是,模型如此就可以由其他人生成自己的指令数据集。这在当想开发私人个性化定制模型,如发送推特、写邮件等,时很方便。这也意味着你可以通过你的邮件来生成一个指令数据集,然后用它来训练一个模型来为你写邮件。

好,那我们来开始吧?我们将进行:

  1. 定义应用场景细节并创建指令的提示词模板
  2. 构建指令数据集
  3. 使用 trlSFTTrainer 指令微调 Llama 2
  4. 测试模型、进行推理

1. 定义应用场景细节并创建指令的提示词模板

在描述应用场景前,我们要更好的理解一下究竟什么是指令。

指令是一段文本或提供给大语言模型,类似 Llama,GPT-4 或 Claude,使用的提示词,用来指导它去生成回复。指令可以让人们做到把控对话,约束模型输出更自然、实用的输出,并使这些结果能够对齐用户的目的。制作清晰的、整洁的指令则是生成高质量对话的关键。

指令的例子如下表所示。

能力 示例指令
头脑风暴 提供一系列新口味的冰淇淋的创意。
分类 根据剧情概要,将这些电影归类为喜剧、戏剧或恐怖片。
确定性问答 用一个单词回答“法国的首都是哪里?”
生成 用罗伯特·弗罗斯特的风格写一首关于大自然和季节变化的诗。
信息提取 从这篇短文中提取主要人物的名字。
开放性问答 为什么树叶在秋天会变色?用科学的理由解释一下。
摘要 用 2-3 句话概括一下这篇关于可再生能源最新进展的文章。

如开头所述,我们想要微调模型,以便根据输入 (或输出) 生成指令。 我们希望将其用作创建合成数据集的方法,以赋予 LLM 和代理个性化能力。

把这个想法转换成一个基础的提示模板,按照 Alpaca 格式.

  1. ### Instruction:
  2. Use the Input below to create an instruction, which could have been used to generate the input using an LLM.
  3. ### Input:
  4. Dear [boss name],
  5. I'm writing to request next week, August 1st through August 4th,
  6. off as paid time off.
  7. I have some personal matters to attend to that week that require
  8. me to be out of the office. I wanted to give you as much advance
  9. notice as possible so you can plan accordingly while I am away.
  10. Please let me know if you need any additional information from me
  11. or have any concerns with me taking next week off. I appreciate you
  12. considering this request.
  13. Thank you, [Your name]
  14. ### Response:
  15. Write an email to my boss that I need next week 08/01 - 08/04 off.

2. 创建指令数据集

在定义了我们的应用场景和提示模板后,我们需要创建自己的指令数据集。创建高质量的指令数据集是获得良好模型性能的关键。研究表明,“对齐,越少越好” 表明,创建高质量、低数量 (大约 1000 个样本) 的数据集可以达到与低质量、高数量的数据集相同的性能。

创建指令数据集有几种方法,包括:

  1. 使用现有数据集并将其转换为指令数据集,例如 FLAN
  2. 使用现有的 LLM 创建合成指令数据集,例如 Alpaca
  3. 人力创建指令数据集,例如 Dolly

每种方法都有其优缺点,这取决于预算、时间和质量要求。例如,使用现有数据集是最简单的,但可能不适合您的特定用例,而使用人力可能是最准确的,但必然耗时、昂贵。也可以结合几种不同方法来创建指令数据集,如 Orca: Progressive Learning from Complex Explanation Traces of GPT-4.

为了简单起见,我们将使用 Dolly,这是一个开源的指令跟踪记录数据集,由数千名 Databricks 员工在 InstructGPT paper 中描述的几个行为类别中生成,包括头脑风暴、分类、确定性回答、生成、信息提取、开放性回答和摘要。

开始编程吧,首先,我们来安装依赖项。

  1. !pip install "transformers==4.31.0" "datasets==2.13.0" "peft==0.4.0" "accelerate==0.21.0" "bitsandbytes==0.40.2" "trl==0.4.7" "safetensors>=0.3.1" --upgrade

我们使用 Datasets library 的 load_dataset() 方法加载 databricks/databricks-dolly-15k 数据集。

  1. from datasets import load_dataset
  2. from random import randrange
  3. # 从hub加载数据集
  4. dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
  5. print(f"dataset size: {len(dataset)}")
  6. print(dataset[randrange(len(dataset))])
  7. # dataset size: 15011

为了指导我们的模型,我们需要将我们的结构化示例转换为通过指令描述的任务集合。我们定义一个 formatting_function ,它接受一个样本并返回一个符合格式指令的字符串。

  1. def format_instruction(sample):
  2. return f"""### Instruction:
  3. Use the Input below to create an instruction, which could have been used to generate the input using an LLM.
  4. ### Input:
  5. {sample['response']}
  6. ### Response:
  7. {sample['instruction']}
  8. """

我们来在一个随机的例子上测试一下我们的结构化函数。

  1. from random import randrange
  2. print(format_instruction(dataset[randrange(len(dataset))]))

3. 使用 trlSFTTrainer 指令微调 Llama 2

我们将使用最近在由 Tim Dettmers 等人的发表的论文“QLoRA: Quantization-aware Low-Rank Adapter Tuning for Language Generation”中介绍的方法。QLoRA 是一种新的技术,用于在微调期间减少大型语言模型的内存占用,且并不会降低性能。QLoRA 的 TL;DR; 是这样工作的:

  • 将预训练模型量化为 4bit 位并冻结它。
  • 附加轻量化的、可训练的适配器层。(LoRA)
  • 在使用冻结的量化模型基于文本内容进行微调时,仅微调适配器层参数。

如果您想了解有关 QLoRA 及其工作原理的更多信息,我建议您阅读 Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA 博客文章。

Flash Attention (快速注意力)

Flash Attention 是一种经过重新排序的注意力计算方法,它利用经典技术 (排列、重计算) 来显著加快速度,将序列长度的内存使用量从二次降低到线性。它基于论文“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”。

TL;DR; 将训练加速了 3 倍。在这儿获得更多信息 FlashAttention。 Flash Attention 目前仅支持 Ampere (A10, A40, A100, …) & Hopper (H100, …) GPU。 你可以检查一下你的 GPU 是否支持,并用下面的命令来安装它:

注意: 如果您的机器的内存小于 96GB,而 CPU 核心数足够多,请减少 MAX_JOBS 的数量。在我们使用的 g5.2xlarge 上,我们使用了 4

  1. python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
  2. pip install ninja packaging
  3. MAX_JOBS=4 pip install flash-attn --no-build-isolation

安装 flash attention 是会需要一些时间 (10-45 分钟)

该示例支持对所有 Llama 检查点使用 Flash Attention,但默认是未启用的。要开启 Flash Attention,请取消代码块中这段的注释, # COMMENT IN TO USE FLASH ATTENTION

  1. import torch
  2. from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  3. use_flash_attention = False
  4. # COMMENT IN TO USE FLASH ATTENTION
  5. # replace attention with flash attention
  6. # if torch.cuda.get_device_capability()[0] >= 8:
  7. # from utils.llama_patch import replace_attn_with_flash_attn
  8. # print("Using flash attention")
  9. # replace_attn_with_flash_attn()
  10. # use_flash_attention = True
  11. # Hugging Face 模型id
  12. model_id = "NousResearch/Llama-2-7b-hf" # non-gated
  13. # model_id = "meta-llama/Llama-2-7b-hf" # gated
  14. # BitsAndBytesConfig int-4 config
  15. bnb_config = BitsAndBytesConfig(
  16. load_in_4bit=True,
  17. bnb_4bit_use_double_quant=True,
  18. bnb_4bit_quant_type="nf4",
  19. bnb_4bit_compute_dtype=torch.bfloat16
  20. )
  21. # 加载模型与分词器
  22. model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map="auto")
  23. model.config.pretraining_tp = 1
  24. # 通过对比doc中的字符串,验证模型是在使用flash attention
  25. if use_flash_attention:
  26. from utils.llama_patch import forward
  27. assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, "Model is not using flash attention"
  28. tokenizer = AutoTokenizer.from_pretrained(model_id)
  29. tokenizer.pad_token = tokenizer.eos_token
  30. tokenizer.padding_side = "right"

SFTTrainer 支持与 peft 的本地集成,这使得高效地指令微调LLM变得非常容易。我们只需要创建 LoRAConfig 并将其提供给训练器。

  1. from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
  2. # 基于 QLoRA 论文来配置 LoRA
  3. peft_config = LoraConfig(
  4. lora_alpha=16,
  5. lora_dropout=0.1,
  6. r=64,
  7. bias="none",
  8. task_type="CAUSAL_LM",
  9. )
  10. # 为训练准备好模型
  11. model = prepare_model_for_kbit_training(model)
  12. model = get_peft_model(model, peft_config)

在开始训练之前,我们需要定义自己想要的超参数 (TrainingArguments)。

  1. from transformers import TrainingArguments
  2. args = TrainingArguments(
  3. output_dir="llama-7-int4-dolly",
  4. num_train_epochs=3,
  5. per_device_train_batch_size=6 if use_flash_attention else 4,
  6. gradient_accumulation_steps=2,
  7. gradient_checkpointing=True,
  8. optim="paged_adamw_32bit",
  9. logging_steps=10,
  10. save_strategy="epoch",
  11. learning_rate=2e-4,
  12. bf16=True,
  13. tf32=True,
  14. max_grad_norm=0.3,
  15. warmup_ratio=0.03,
  16. lr_scheduler_type="constant",
  17. disable_tqdm=True # 当配置的参数都正确后可以关闭tqdm
  18. )

我们现在有了用来训练模型 SFTTrainer 所需要准备的每一个模块。

  1. from trl import SFTTrainer
  2. max_seq_length = 2048 # 数据集的最大长度序列
  3. trainer = SFTTrainer(
  4. model=model,
  5. train_dataset=dataset,
  6. peft_config=peft_config,
  7. max_seq_length=max_seq_length,
  8. tokenizer=tokenizer,
  9. packing=True,
  10. formatting_func=format_instruction,
  11. args=args,
  12. )

通过调用 Trainer 实例上的 train() 方法来训练我们的模型。

  1. # 训练
  2. trainer.train() # tqdm关闭后将不显示进度条信息
  3. # 保存模型
  4. trainer.save_model()

不使用 Flash Attention 的训练过程在 g5.2xlarge 上花费了 03:08:00。实例的成本为 1,212$/h ,总成本为 3.7$

使用 Flash Attention 的训练过程在 g5.2xlarge 上花费了 02:08:00。实例的成本为 1,212$/h ,总成本为 2.6$

使用 Flash Attention 的结果令人满意,速度提高了 1.5 倍,成本降低了 30%。

4. 测试模型、进行推理

在训练完成后,我们想要运行和测试模型。我们会使用 pefttransformers 将 LoRA 适配器加载到模型中。

  1. if use_flash_attention:
  2. # 停止 flash attention
  3. from utils.llama_patch import unplace_flash_attn_with_attn
  4. unplace_flash_attn_with_attn()
  5. import torch
  6. from peft import AutoPeftModelForCausalLM
  7. from transformers import AutoTokenizer
  8. args.output_dir = "llama-7-int4-dolly"
  9. # 加载基础LLM模型与分词器
  10. model = AutoPeftModelForCausalLM.from_pretrained(
  11. args.output_dir,
  12. low_cpu_mem_usage=True,
  13. torch_dtype=torch.float16,
  14. load_in_4bit=True,
  15. )
  16. tokenizer = AutoTokenizer.from_pretrained(args.output_dir)

我们来再次用随机样本加载一次数据集,试着来生成一条指令。

  1. from datasets import load_dataset
  2. from random import randrange
  3. # 从hub加载数据集并得到一个样本
  4. dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
  5. sample = dataset[randrange(len(dataset))]
  6. prompt = f"""### Instruction:
  7. Use the Input below to create an instruction, which could have been used to generate the input using an LLM.
  8. ### Input:
  9. {sample['response']}
  10. ### Response:
  11. """
  12. input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
  13. # with torch.inference_mode():
  14. outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9,temperature=0.9)
  15. print(f"Prompt:\n{sample['response']}\n")
  16. print(f"Generated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")
  17. print(f"Ground truth:\n{sample['instruction']}")

太好了!我们的模型可以工作了!如果想要加速我们的模型,我们可以使用 Text Generation Inference 部署它。因此我们需要将我们适配器的参数合并到基础模型中去。

  1. from peft import AutoPeftModelForCausalLM
  2. model = AutoPeftModelForCausalLM.from_pretrained(
  3. args.output_dir,
  4. low_cpu_mem_usage=True,
  5. )
  6. # 合并 LoRA 与 base model
  7. merged_model = model.merge_and_unload()
  8. # 保存合并后的模型
  9. merged_model.save_pretrained("merged_model",safe_serialization=True)
  10. tokenizer.save_pretrained("merged_model")
  11. # push合并的模型到hub上
  12. # merged_model.push_to_hub("user/repo")
  13. # tokenizer.push_to_hub("user/repo")

原文作者: Philschmid

原文链接: https://www.philschmid.de/instruction-tune-llama-2

译者: Xu Haoran

扩展说明: 指令微调 Llama 2的更多相关文章

  1. 解密Prompt系列6. lora指令微调扣细节-请冷静,1个小时真不够~

    上一章介绍了如何基于APE+SELF自动化构建指令微调样本.这一章咱就把微调跑起来,主要介绍以Lora为首的低参数微调原理,环境配置,微调代码,以及大模型训练中显存和耗时优化的相关技术细节 标题这样写 ...

  2. 在一张 24 GB 的消费级显卡上用 RLHF 微调 20B LLMs

    我们很高兴正式发布 trl 与 peft 的集成,使任何人都可以更轻松地使用强化学习进行大型语言模型 (LLM) 微调!在这篇文章中,我们解释了为什么这是现有微调方法的有竞争力的替代方案. 请注意, ...

  3. COIG:开源四类中文指令语料库

    CHINESE OPEN INSTRUCTION GENERALIST: A PRELIMINARY RELEASE 论文:https://arxiv.org/pdf/2304.07987v1.pdf ...

  4. x86汇编指令详解

    80x86指令系统 80x86指令系统,指令按功能可分为以下七个部分. (1) 数据传送指令. (2) 算术运算指令. (3) 逻辑运算指令. (4) 串操作指令. (5) 控制转移指令. (6) 处 ...

  5. QPBOC扩展应用交易流程

    1 Q扩展部分数据 增加3个DGI,分别为:A001,8020,9020 9103中增加DF60(9F38中),DF61 增加DF62,DF63 1.1  A001扩展应用配置 DGI 长度 值(示例 ...

  6. x86汇编指令具体解释

    80x86指令系统 80x86指令系统,指令按功能可分为下面七个部分. (1) 数据传送指令. (2) 算术运算指令. (3) 逻辑运算指令. (4) 串操作指令. (5) 控制转移指令. (6) 处 ...

  7. x86 体系指令

    FASM 第二章 - 2.1 x86 体系指令 Author: 徐艺波  From: xuyibo.org  Updated: 2008-04-17   官方论坛   本站软件反馈.软件开发交流.   ...

  8. 第二百一十八节,jQuery EasyUI,TimeSpinner(时间微调)组件

    jQuery EasyUI,TimeSpinner(时间微调)组件 学习要点: 1.加载方式 2.属性列表 3.事件列表 4.方法列表 本节课重点了解 EasyUI 中 TimeSpinner(时间微 ...

  9. SIMD数据并行(二)——多媒体SIMD扩展指令集

    在计算机体系中,数据并行有两种实现路径:MIMD(Multiple Instruction Multiple Data,多指令流多数据流)和SIMD(Single Instruction Multip ...

  10. 第18章-x86指令集之常用指令

    x86的指令集可分为以下4种: 通用指令 x87 FPU指令,浮点数运算的指令 SIMD指令,就是SSE指令 系统指令,写OS内核时使用的特殊指令 下面介绍一些通用的指令.指令由标识命令种类的助记符( ...

随机推荐

  1. P1765

    和那道题一样,这次用的getchar,结果对了可是洛谷评测WA了,换成scanf单个字符,结果还是WA了,换成直接getline读入整个字符串就对了. 可见读入单个字符的方式有可能出现各种小错,尤其是 ...

  2. java 对象作为成员变量

    public class Main { private int uplimit; private int value; public Main(int uplimit){ this.uplimit = ...

  3. Python数据可视化-动态柱状图可视化

    Python数据可视化-动态柱状图可视化 一.基础柱状图 通过Bar构建基础柱状图 """ 演示基础柱状图的开发 """ from pyec ...

  4. nohup 与 >/dev/null 与 2>&1 作用与区别

    转载请注明出处: 在 Linux 中,>/dev/null 和 2>&1 是两个常用的重定向操作,它们用于控制命令的输出和错误信息.而且这两个参数经常 与 nohup 命令一起使用 ...

  5. Feign 进行rpc 调用时使用ribbon负载均衡源码解析

    转载请注明出处: Feign客户端接口的动态代理生成是基于JDK的动态代理来实现的,那么在所有的方法调用的时候最终都会走InvocationHandler接口的实现,默认就是ReflectiveFei ...

  6. wireshark 显示过滤表达式

    转载请注明出处: 1.根据协议过滤: 在显示过滤表达式的输入框中直接输入对应的协议类型即可:http   tcp  udp 2.根据 IP 过滤: 根据源IP地址过滤:如源地址IP为:127.0.0. ...

  7. 【mysql】 解决 auto_increment 字段 Column count doesn't match value count at row 1

    1, 表结构   man +-------+-------------+------+-----+---------+----------------+| id | int(11) | NO | PR ...

  8. [转帖]针对容器的nginx优化

    针对容器的nginx优化 本篇文章介绍了 Nginx 在容器内使用遇到的CPU核数获取问题以及对应的解决方法. 回顾上篇文章:TCP 半连接队列和全连接队列 背景 容器技术越来越普遍,很多公司已经将容 ...

  9. [转帖]Jmeter压力测试工具安装及使用教程

    https://www.cnblogs.com/monjeo/p/9330464.html 一.Jmeter下载 进入官网:http://jmeter.apache.org/ 1.第一步进入官网如下图 ...

  10. 【转帖】【性能提升神器】STRAIGHT_JOIN

    今天给大家下另一个性能提升神器-STRAIGHT_JOIN,在数据量大的联表查询中灵活运用的话,能大大缩短查询时间. 首先来解释下STRAIGHT_JOIN到底是用做什么的: STRAIGHT_JOI ...