人工智能LLM模型:奖励模型的训练、PPO 强化学习的训练、RLHF

1.奖励模型的训练

1.1大语言模型中奖励模型的概念

在大语言模型完成 SFT 监督微调后,下一阶段是构建一个奖励模型来对问答对作出得分评价。奖励模型源于强化学习中的奖励函数,能对当前的状态刻画一个分数,来说明这个状态产生的价值有多少。在大语言模型微调中的奖励模型是对输入的问题和答案计算出一个分数。输入的答案与问题匹配度越高,则奖励模型输出的分数也越高。

1.2 奖励模型的模型架构与损失函数

1.2.1 模型架构

奖励模型(RM 模型)将 SFT 模型最后一层的 softmax 去掉,即最后一层不用 softmax,改成一个线性层。RM 模型的输入是问题和答案,输出是一个标量即分数。

由于模型太大不够稳定,损失值很难收敛且小模型成本较低,因此,RM 模型采用参数量为 6B 的模型,而不使用 175B 的模型。

1.2.2 损失函数

奖励模型的训练数据是人工对问题的每个答案进行排名,如下图所示:

对于每个问题,给出若干答案,然后工人进行排序,而奖励模型就是利用排序的结果来进行反向传播训练。奖励模型的损失函数采用 Pairwise Ranking Loss,公式如下所示:

$loss(θ)=−(K2​)1​E(x,yw​,yl​) D​[log(σ(rθ​(x,yw​)−rθ​(x,yl​)))]$

其中:

D:人工对答案进行排序的数据集;

x:数据集D中的问题;

K:每个问题对应的答案数量;

yw​yl​:问题x对应的K个答案中的两个,且yw​的排序比yl​高,由于是一对,也称 pairwiserθ​(x,y):需要训练的 RM 模型,对于输入的一对xy得到的标量分数;

θ:RM 模型需要优化的参数。

如何理解 RM 模型的损失函数呢?

RM 模型的目标是使得排序高的答案yw​对应的标量分数要高于排序低的答案yl​对应的标量分数,且越高越好,也就是使得损失函数中的rθ​(x,yw​)−rθ​(x,yl​)这个差值越大越好。将相减后的分数通过 sigmoid 函数,差值变成 - 1 到 1 之间,由于 sigmoid 函数是单调递增的函数,因此σ(rθ​(x,yw​)−rθ​(x,yl​))越大越好。σ(rθ​(x,yw​)−rθ​(x,yl​))约接近 1,表示yw​yl​排序高,属于 1 这个分类,反正属于 - 1 这个分类,所以这里也可以看成是一个二分类问题。再加上 logistic 函数,也就是相当于交叉熵损失函数。对于每个问题都有K个答案,在损失函数前除以CK2​,使得损失函数值不会因为K的变化而变化太多。损失函数的最终目标是最小化loss(θ),与最大化rθ​(x,yw​)−rθ​(x,yl​)相对应。

奖励模型中每个问题对应的答案数量即K值为什么选 9 更合适,而不是选择 4 呢?

  • 进行标注的时候,需要花很多时间去理解问题,但答案之间比较相近,假设 4 个答案进行排序要 30 秒时间,那么 9 个答案排序可能就 40 秒就够了。9 个答案与 4 个答案相比生成的问答对多了 5 倍,从效率上来看非常划算;
  • K=9时,每次计算 loss 都有 36 项rθ​(x,y)需要计算,RM 模型的计算所花时间较多,但可以通过重复利用之前算过的值(也就是只需要计算 9 次即可),能节约很多时间。

奖励模型的损失函数为什么会比较答案的排序,而不是去对每一个答案的具体分数做一个回归?

每个人对问题的答案评分都不一样,无法使用一个统一的数值对每个答案进行打分。如果采用对答案具体得分回归的方式来训练模型,会造成很大的误差。但是,每个人对答案的好坏排序是基本一致的。通过排序的方式避免了人为的误差。

1.3 总结

奖励模型通过与人类专家进行交互,获得对于生成响应质量的反馈信号,从而进一步提升大语言模型的生成能力和自然度。与监督模型不同的是,奖励模型通过打分的形式使得生成的文本更加自然逼真,让大语言模型的生成能力更进一步。

2.PPO 强化学习的训练

2.1 PPO 强化学习概念

大语言模型完成奖励模型的训练后,下一个阶段是训练强化学习模型(RL 模型),也是最后一个阶段。大语言模型微调中训练 RL 模型采用的优化算法是 PPO(Proximal Policy Optimization,近端策略优化)算法,即对设定的目标函数通过随机梯度下降进行优化。近端策略优化是一种深度强化学习算法,用于训练智能体在复杂环境中学习和执行任务。通过智能体的训练,使得其在与环境的交互中能够最大化累积回报,从而达成指定任务目标。这里的智能体在大语言模型中指的就是 RL 模型。

2.2 PPO 强化学习原理

RL 模型的初始模型采用 SFT 微调之后的大语言预训练模型。训练 RL 模型的数据集只需要收集问题集(Prompt 集),不需要对问题进行标注。问题集通过 RL 模型生成答案文本,然后将问题和答案输入上一步训练的 RW 模型进行打分,来评价生成的文本质量,而训练 RL 模型的目标是使得生成的文本要在 RW 模型上获得尽可能高的得分。

将初始语言模型的微调任务建模为强化学习(RL)问题,需要定义策略(policy)、动作空间(action space)和奖励函数(reward function)等基本要素。

策略就是基于该语言模型,接收 prompt 作为输入,然后输出一系列文本(或文本的概率分布);而动作空间就是词表所有 token 在所有输出位置的排列组合;观察空间则是可能的输入 token 序列(即 prompt),为词表所有 token 在所有输入位置的排列组合;而奖励函数则是上一阶段训好的 RM 模型,配合一些策略层面的约束进行的奖励计算。该阶段流程如下图所示:

RL 模型训练的损失函数公式如下:

$objective(ϕ)=E(x,y)∼DπϕRL​​​[rθ​(x,y)−βlog(πϕRL​(y∣x)/πSFT(y∣x))]+γEx∼Dpretrain​​[log(πϕRL​(x))]$

其中:

πSFT:SFT 模型;

πϕRL​:强化学习中,模型叫做 Policy,πϕRL​就是需要调整的模型,即最终模型。初始化是πSFT(x,y)∼DπϕRL​​x是 RL 数据集中的问题,yx通过πϕRL​模型得到的答案;

rθ​(x,y):对问题x和答案y进行打分的 RM 模型;

πϕRL​(y∣x):问题x通过πϕRL​得到答案y的概率,即对于每一个y的预测和它的 softmax 的输出相乘;

πSFT(y∣x):问题x通过πSFT得到答案y的概率;

x∼Dpretrain​x是来自大语言模型预训练阶段的数据;

βγ:调整系数。

RL 模型的优化目标是使得损失函数越大越好,损失函数可以分为三个部分,打分部分、KL 散度部分以及预训练部分。

  • 打分部分:将 RL 模型的问题数据集x,通过πϕRL​模型得到答案y,然后再把这对(x,y)代入 RW 模型进行打分,即损失函数公式中的rθ​(x,y)。该分数越高,代表模型生成的答案越好。
  • KL 散度部分:在每次更新参数后,πϕRL​会发生变化,x通过πϕRL​生成的y也会发生变化,而rθ​(x,y)奖励模型是根据πSFT模型的数据训练而来。如果πϕRL​πSFT差的太多,则会导致rθ​(x,y)的分数估算不准确。因此需要通过 KL 散度来计算,πϕRL​生成的答案分布和πSFT生成的答案分布之间的距离,使得两个模型之间不要差的太远。损失函数公式中的log(πϕRL​(y∣x)/πSFT(y∣x))就是在计算 KL 散度。由于 KL 散度是越小越好,而训练目标是损失函数越大越好,因此在前面需要加上一个负号。
  • 预训练部分:预训练部分对应损失函数中的Ex∼Dpretrain​​[log(πϕRL​(x))]。如果没有该项,那么模型最终可能只对这一个任务能够做好,在别的任务上会发生性能下降。因此,需要将预训练阶段的目标函数加上,使得前面两个部分在新的数据集上做拟合的同时保证原始的数据也不会丢弃。

最终优化后的πϕRL​模型就是大语言模型的最终模型。

2.3 总结

通过强化学习的训练方法,迭代式的更新奖励模型(RW 模型)以及策略模型(RL 模型),让奖励模型对模型输出质量的刻画愈加精确,策略模型的输出则愈能与初始模型拉开差距,使得输出文本变得越来越符合人的认知。这种训练方法也叫做 RLHF。

目前,RLHF 技术对训练大语言模型具有极大的影响力,训练出来的效果好于之前的方法。但是,RLHF 训练出来的大语言模型仍然可能输出有害或事实上不准确的文本,需要不断不断改进。此外,在基于 RLHF 范式训练模型时,人工标注的成本还是非常高昂的,RLHF 性能最终仅能达到标注人员的知识水平。这里的人工标注主要是为 RM 模型标注输出文本的排序结果,而若想要用人工去撰写答案的方式来训练模型,那成本更是不可想象。

3.关键知识点

  1. 大语言模型微调中的奖励模型训练:1.奖励模型输入问答对,输出得分 2.奖励模型的损失函数目的是使得得分较高的答案比得分较低的答案尽可能大,3.奖励模型是判别式模型

  2. 奖励模型是:监督学习、强化学习、判别式模型

  3. 大语言模型训练中的PPO强化学习:1.在大语言模型训练中,强化学习模型架构与SFT监督微调的模型一样,2.RLHF中训练强化学习模型阶段不需要标注问题的答案 3.RLHF中的初始策略就是SFT模型

  4. 关于RLHF方法中RL模型训练的损失函数:1.RL模型的损失函数包含三个部分 2.RL模型的损失函数需要计算策略更新后的RL模型与SFT模型输出的KL散度 3.RL模型的损失函数需要计算大语言模型预训练阶段的损失函数 4.RL模型的损失函数要使得RL模型生成的文本在奖励模型中的得分越高越好

  5. RLHF本质上是通过人类的反馈来优化模型,生成的文本会更加的自然。

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

人工智能LLM模型:奖励模型的训练、PPO 强化学习的训练、RLHF的更多相关文章

  1. 强化学习(十七) 基于模型的强化学习与Dyna算法框架

    在前面我们讨论了基于价值的强化学习(Value Based RL)和基于策略的强化学习模型(Policy Based RL),本篇我们讨论最后一种强化学习流派,基于模型的强化学习(Model Base ...

  2. 强化学习中的无模型 基于值函数的 Q-Learning 和 Sarsa 学习

    强化学习基础: 注: 在强化学习中  奖励函数和状态转移函数都是未知的,之所以有已知模型的强化学习解法是指使用采样估计的方式估计出奖励函数和状态转移函数,然后将强化学习问题转换为可以使用动态规划求解的 ...

  3. 强化学习 3—— 使用蒙特卡洛采样法(MC)解决无模型预测与控制问题

    一.问题引入 回顾上篇强化学习 2 -- 用动态规划求解 MDP我们使用策略迭代和价值迭代来求解MDP问题 1.策略迭代过程: 1.评估价值 (Evaluate) \[v_{i}(s) = \sum_ ...

  4. 基于TORCS和Torch7实现端到端连续动作自动驾驶深度强化学习模型(A3C)的训练

    基于TORCS(C++)和Torch7(lua)实现自动驾驶端到端深度强化学习模型(A3C-连续动作)的训练 先占坑,后续内容有空慢慢往里填 训练系统框架 先占坑,后续内容有空慢慢往里填 训练系统核心 ...

  5. ICML 2018 | 从强化学习到生成模型:40篇值得一读的论文

    https://blog.csdn.net/y80gDg1/article/details/81463731 感谢阅读腾讯AI Lab微信号第34篇文章.当地时间 7 月 10-15 日,第 35 届 ...

  6. 伯克利、OpenAI等提出基于模型的元策略优化强化学习

    基于模型的强化学习方法数据效率高,前景可观.本文提出了一种基于模型的元策略强化学习方法,实践证明,该方法比以前基于模型的方法更能够应对模型缺陷,还能取得与无模型方法相近的性能. 引言 强化学习领域近期 ...

  7. 强化学习之五:基于模型的强化学习(Model-based RL)

    本文是对Arthur Juliani在Medium平台发布的强化学习系列教程的个人中文翻译,该翻译是基于个人分享知识的目的进行的,欢迎交流!(This article is my personal t ...

  8. 强化学习之 免模型学习(model-free based learning)

    强化学习之 免模型学习(model-free based learning) ------ 蒙特卡罗强化学习 与 时序查分学习 ------ 部分节选自周志华老师的教材<机器学习> 由于现 ...

  9. 【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法

    论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning 论文地址:https://arxiv.org/ ...

  10. 论文:利用深度强化学习模型定位新物体(VISUAL SEMANTIC NAVIGATION USING SCENE PRIORS)

    这是一篇被ICLR 2019 接收的论文.论文讨论了如何利用场景先验知识 (scene priors)来定位一个新场景(novel scene)中未曾见过的物体(unseen objects).举例来 ...

随机推荐

  1. MAC 转 Byte[] 数组

    MAC 转 Byte[] 数组 /** * MAC 地址转 byte[] * 默认以小端序转换 * * @param macAddr "E4:54:E8:81:FC:FD" * @ ...

  2. MAC zsh:no matches found

    jimmy@MacBook-Pro bin % wsdl2java http://www.webxml.com.cn/WebServices/IpAddressSearchWebService.asm ...

  3. JVM HotSpot 可达性分析算法实现细节

    本文部分摘自<深入理解 Java 虚拟机第三版> 根节点枚举 在之前关于可达性分析算法的介绍中我们讲过,我们需要先找出可固定作为 GC Roots 的节点,然后沿着引用链去寻找那些无用的垃 ...

  4. GO语言之环境搭建和基本命令

    目录 go语言基础 下载go编译器 go目录简介 gopath简介 环境变量配置 GOPATH PATH go语言项目结构 IDE下载与配置 安装goland goland里添加goroot和gopa ...

  5. 阿里云视频云人脸生成领域最新研究成果入选CVPR2022

    CVPR(IEEE Conference on Computer Vision and Pattern Recognition)作为计算机视觉和模式识别领域的顶级会议,在全球具有极高的权威性.目前在中 ...

  6. Google C++编程规范(Google C++ Style Guide)

    参考链接: Google 代码规范 C++总结 Google 开源项目风格指南--中文版 Google C++ Style Guide是一份不错的C++编码指南,我制作了一张比较全面的说明图,可以在短 ...

  7. Golang之文件系统事件监听

    Golang之文件系统事件监听 基本介绍 文件系统事件是指文件系统相关的各种操作和状态变化,当一个应用层的进程操作文件或目录时,会触发system call,内核的notification子系统可以守 ...

  8. 如何把thinkphp5的项目迁移到阿里云函数计算来应对流量洪峰?

    原文链接:https://developer.aliyun.com/article/982746 1. 为什么要迁移到阿里云函数? 我的项目是一个节日礼品领取项目,过节的时候会有短时间的流量洪峰.平时 ...

  9. linux有用的命令

    如下是一些在工作中偶尔会用到,每次用的时候都要查一查资料的命令这里总结一下方便今后查阅 0.查看操作系统版本 cat /etc/issue  或 cat /etc/redhat-release 1.后 ...

  10. mybatis-plus-QueryWrapper like的用法

    mybatis-plus 中想写like的语句 一.直接用 QueryWrapper 中的 like String deptLevelCodeTemp = "1000010001" ...