微调类型简介

1. SFT监督微调:适用于在源任务中具有较高性能的模型进行微调,学习率较小。常见任务包括中文实体识别、语言模型训练、UIE模型微调。优点是可以快速适应目标任务,但缺点是可能需要较长的训练时间和大量数据。

2. LoRA微调:通过高阶矩阵秩的分解减少微调参数量,不改变预训练模型参数,新增参数。优点是减少了微调的参数量和成本,同时能达到与全模型微调相近的效果。

3. P-tuning v2微调:引入了prefix-tuning的思想,每一层都加入了prefix,并采用了多任务学习。解决了P-tuning v1中序列标注任务效果不佳和普遍性差的问题。其参数对象是各层的prefix。优点是适用于多任务学习,但在自然语言理解任务上表现可能不佳。

4. Freeze微调:主要用于大语言模型的微调,后几层网络提取语义特征,前几层提取文本表层特征。优点是参数高效,适用于提取特定层次的特征。

综上所述,各种微调方法适用于不同的场景和任务。SFT监督微调适用于快速适应目标任务,LoRA适用于减少参数量和成本,P-tuning v2适用于多任务学习,而Freeze适用于提取特定层次的特征。

1.下载glm2训练脚本

git clone https://github.com/THUDM/ChatGLM2-6B.git

2.然后使用 pip 安装依赖

pip install -r requirements.txt -i https://pypi.douban.com/simple/

运行行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖

pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/

3.下载样例数据或者自己构建样例

{"content": "类型#裙_材质#网纱_颜色#粉红色_图案#线条_图案#刺绣_裙腰型#高腰_裙长#连衣裙_裙袖长#短袖_裙领型#圆领", "summary": "这款连衣裙,由上到下都透出女性魅力,经典圆领型,开口度恰好,露出修长的脖颈线条,很是优雅气质,短袖设计,这款对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}

可以根据以上格式,构建自己的训练样本,我们可以用一些行业生产数据,如会话记录对模型进行训练,

官方示例数据下载:

https%3A//cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/%3Fdl%3D1

4.根据自己的环境修改训练脚本中对应的文件地址

PRE_SEQ_LEN=128  #序列的预设长度为128
LR=2e-2 #学习率为0.02
NUM_GPUS=4 #用几颗GPU进行训练 torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \
--do_train \
--train_file /export/data/train.json \ #设置训练数据文件的目录
--validation_file /export/data/validation.json \ #设置验证文件的目录
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \ #模型目录
--output_dir /export/models/trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \ #训练后的模型目录
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 128 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4

5.开始训练吧

sh train.sh

训练中

快要训练完成

6.训练完成

Training completed. Do not forget to share your model on huggingface.co/models =)

{'train_runtime': 4598.3849, 'train_samples_per_second': 41.754, 'train_steps_per_second': 0.652, 'train_loss': 0.1287700497706731, 'epoch': 2400.0}

100%|██████████| 3000/3000 [1:16:37<00:00, 1.53s/it]

***** train metrics *****

epoch = 2400.0

train_loss = 0.1288

train_runtime = 1:16:38.38

train_samples = 24

train_samples_per_second = 41.754

train_steps_per_second = 0.652

7.部署训练后的模型

在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重

        model_path = "/opt/tritonserver/python_backend/models/chatglm2-6b"
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join('/opt/train/trained-chatglm2-6b-pt-128-1e-4/checkpoint-3000', "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

8.过程中遇到的问题

8.1 微调后无法应答

PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1 torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \
--do_train \
--train_file train.json \
--validation_file dev.json \
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \
--output_dir trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \

使用官方脚本中的学习率设置 LR=2e-2 (0.02)

模型出现无法应答,灾难性遗忘,基本上原有的知识都遗忘了,无法应答普通提问 , 比如"你好.."

于是尝试使用 LR=1e-4 (0.0001) 进行训练

"1e-4" 表示 1 乘以 10 的 -4 次方,即等于 0.0001,"2e-2" 表示 2 乘以 10 的 -2 次方,即等于 0.02。

模型最终可以应答.

镜像问题:

https://github.com/THUDM/ChatGLM-6B/issues/1148

8.2 关于学习率:

我理解是,学习率大小像看书看的粗细,看的太粗就学的快(收敛快)但啥也学不到,

学习率是影响模型训练效果的重要参数。过大的学习率可能导致模型不稳定,过小的学习率则可能导致训练速度变慢。因此,需要反复试验,找到合适的学习率。

学习率(lr)表示每次更新权重参数的尺度(步长),ΔΘ=Θ0−(lr)(loss′)。

学习率与batch_size在权重更新中的关系

学习率(lr)直观可以看出lr越大,权重更新的跨度越大,模型参数调整变化越快。

batch_size对模型的影响,在于模型每次更新时,计算梯度是计算整个Batch的平均梯度,

即权重更新公式中的loss′=1batchsize(lossbatch)′, 整合就是 ΔΘ=Θ0−(lr)1batchsize(lossbatch)′ 。即lr与batch_size共同影响模型更新。

作者:京东科技 杨建

来源:京东云开发者社区 转发请注明来源

基于 P-Tuning v2 进行 ChatGLM2-6B 微调实践的更多相关文章

  1. 【快报】基于K2 BPM的新一代协同办公门户实践交流会

    2014年2月28日,“基于BPM的新一代协同办公门户”用户实践交流活动在深圳金茂JW万豪酒店3楼Meet Room IV举办.本次会议由K2携手微软共同举办,邀请到的参会企业都是K2 的BPM老客户 ...

  2. 基于Sql Server 2008的分布式数据库的实践(五)

    原文 基于Sql Server 2008的分布式数据库的实践(五) 程序设计 ------------------------------------------------------------- ...

  3. 基于Sql Server 2008的分布式数据库的实践(四)

    原文 基于Sql Server 2008的分布式数据库的实践(四) 数据库设计 1.E-R图 2.数据库创建 Win 7 1 create database V3 Win 2003 1 create  ...

  4. 基于Sql Server 2008的分布式数据库的实践(三)

    原文 基于Sql Server 2008的分布式数据库的实践(三) 配置PHP 1.打开PHP配置文件,找到extension=php_mssql.dll,将前面的注释符号去掉 2.找到mssql.s ...

  5. 基于Sql Server 2008的分布式数据库的实践(二)

    原文 基于Sql Server 2008的分布式数据库的实践(二) 从Win7连接Win2003的Sql Server 2008 1.新建链接服务器链接到Win2003的Sql Server 2008 ...

  6. 基于Sql Server 2008的分布式数据库的实践(一)

    原文 基于Sql Server 2008的分布式数据库的实践(一) 配置Sql Server 2008(Win7) 1.打开SQL server2012,使用windows身份登录 2.登录后,右键选 ...

  7. 【公开课】【阿里在线技术峰会】魏鹏:基于Java容器的多应用部署技术实践

    对于公开课,可能目前用不上这些,但是往往能在以后想解决方案的时候帮助到我.以下是阿里对公开课的整理 摘要: 在首届阿里巴巴在线峰会上,阿里巴巴中间件技术部专家魏鹏为大家带来了题为<基于Java容 ...

  8. 滴滴出行基于RocketMQ构建企业级消息队列服务的实践

    小结: 1. https://mp.weixin.qq.com/s/v6NM3UgX-qTI7yO1QPCJrw 滴滴出行基于RocketMQ构建企业级消息队列服务的实践 原创: 江海挺 阿里巴巴中间 ...

  9. Python 基于Python从mysql表读取千万数据实践

    基于Python 从mysql表读取千万数据实践   by:授客 QQ:1033553122 场景:   有以下两个表,两者都有一个表字段,名为waybill_no,我们需要从tl_waybill_b ...

  10. 苏宁基于Spark Streaming的实时日志分析系统实践 Spark Streaming 在数据平台日志解析功能的应用

    https://mp.weixin.qq.com/s/KPTM02-ICt72_7ZdRZIHBA 苏宁基于Spark Streaming的实时日志分析系统实践 原创: AI+落地实践 AI前线 20 ...

随机推荐

  1. GPT3的技术突破:实现更准确、更真实的语言生成

    目录 1. 引言 2. 技术原理及概念 3. 实现步骤与流程 4. 应用示例与代码实现讲解 5. 优化与改进 6. 结论与展望 7. 附录:常见问题与解答 GPT-3 技术突破:实现更准确.更真实的语 ...

  2. vivo 自研鲁班分布式 ID 服务实践

    作者:vivo IT 平台团队- An Peng 本文介绍了什么是分布式ID,分布式ID的业务场景以及9种分布式ID的实现方式,同时基于vivo内部IT的业务场景,介绍了自研鲁班分布式ID服务的实践. ...

  3. 我坚定的认为,这个源码肯定是有 BUG 的!

    你好呀,我是歪歪. 上周我不是发了<我试图给你分享一种自适应的负载均衡.>这篇文章嘛,里面一种叫做"自适应负载均衡"的负载均衡策略,核心思路就是从多个服务提供者中随机选 ...

  4. 【SpringBoot】条件装配 @profile

    profile 使用说明: @profile注解的作用是指定类或方法在特定的 Profile 环境生效,任何@Component或@Configuration注解的类都可以使用@Profile注解. ...

  5. 【SpringBoot】整合Redis

    1.前言 最近公司在做项目,用到了redis,,发现自己一点都不会,然后就乘闲暇时间,自己学习一些redis相关的知识,在这里分享给像我一样的初学者. 2.我的项目结构: 2.1 pom.xml &l ...

  6. 《逆向工程核心原理》之DLL注入

    DLL注入 DLL注入指的是向运行中的其他进程强制插入特定的DLL文件.从技术细节来说,DLL注入命令其他进程自行调用LoadLibrary() API,加载(Loading)用户指定的DLL文件.D ...

  7. 使用Locust进行性能测试

    当涉及到评估应用程序或服务的性能时,Locust是一个功能强大且易于使用的开源工具.本文将介绍Locust的基本概念和使用方法. 什么是Locust? Locust是一个用于编写.运行和分析负载测试的 ...

  8. RestSharp HTTP请求库

    官方文档:https://restsharp.dev/intro.html#introduction c# RestSharp(http请求):https://blog.csdn.net/czjnoe ...

  9. js的一些小问题集合

    1.等于号的应用 function reverse(){ var checkbox = document.getElementsByName("hobby"); for (let ...

  10. Unity Shader编辑器工具类ShaderUtil 常用函数和用法

    Unity Shader编辑器工具类ShaderUtil 常用函数和用法 Unity的Shader编辑器工具类ShaderUtil提供了一系列函数,用于编译.导入和管理着色器.本文将介绍ShaderU ...