基于人类反馈的强化学习,Reinforcement Learning from Human Feedback (RLHF)
基于人类反馈的强化学习, RLHF,转载参考链接
RLHF 是一项涉及多个模型和不同训练阶段的复杂概念,可以按三个步骤分解:
- 预训练一个语言模型 (LM) ;
- 聚合问答数据并训练一个奖励模型 (Reward Model,RM) ;
- 用强化学习 (RL) 方式微调 LM。
Step 1. 预训练语言模型
首先,我们使用经典的预训练目标训练一个语言模型。对这一步的模型,OpenAI 在其第一个流行的 RLHF 模型 InstructGPT 中使用了较小版本的 GPT-3; Anthropic 使用了 1000 万 ~ 520 亿参数的 Transformer 模型进行训练;DeepMind 使用了自家的 2800 亿参数模型 Gopher。
这里可以用额外的文本或者条件对这个 LM 进行微调,例如 OpenAI 对 “更可取” (preferable) 的人工生成文本进行了微调,而 Anthropic 按 “有用、诚实和无害” 的标准在上下文线索上蒸馏了原始的 LM。这里或许使用了昂贵的增强数据,但并不是 RLHF 必须的一步。由于 RLHF 还是一个尚待探索的领域,对于” 哪种模型” 适合作为 RLHF 的起点并没有明确的答案。
接下来,我们会基于 LM 来生成训练 奖励模型 (RM,也叫偏好模型) 的数据,并在这一步引入人类的偏好信息。
Step 2. 训练奖励模型
RM 的训练是 RLHF 区别于旧范式的开端。这一模型接收一系列文本并返回一个标量奖励,数值上对应人的偏好。我们可以用端到端的方式用 LM 建模,或者用模块化的系统建模 (比如对输出进行排名,再将排名转换为奖励) 。这一奖励数值将对后续无缝接入现有的 RL 算法至关重要。
关于模型选择方面,RM 可以是另一个经过微调的 LM,也可以是根据偏好数据从头开始训练的 LM。例如 Anthropic 提出了一种特殊的预训练方式,即用偏好模型预训练 (Preference Model Pretraining,PMP) 来替换一般预训练后的微调过程。因为前者被认为对样本数据的利用率更高。但对于哪种 RM 更好尚无定论。
关于训练文本方面,RM 的提示 - 生成对文本是从预定义数据集中采样生成的,并用初始的 LM 给这些提示生成文本。Anthropic 的数据主要是通过 Amazon Mechanical Turk 上的聊天工具生成的,并在 Hub 上 可用,而 OpenAI 使用了用户提交给 GPT API 的 prompt。
关于训练奖励数值方面,这里需要人工对 LM 生成的回答进行排名。起初我们可能会认为应该直接对文本标注分数来训练 RM,但是由于标注者的价值观不同导致这些分数未经过校准并且充满噪音。通过排名可以比较多个模型的输出并构建更好的规范数据集。
对具体的排名方式,一种成功的方式是对不同 LM 在相同提示下的输出进行比较,然后使用 Elo 系统建立一个完整的排名。这些不同的排名结果将被归一化为用于训练的标量奖励值。
这个过程中一个有趣的产物是目前成功的 RLHF 系统使用了和生成模型具有 不同 大小的 LM (例如 OpenAI 使用了 175B 的 LM 和 6B 的 RM,Anthropic 使用的 LM 和 RM 从 10B 到 52B 大小不等,DeepMind 使用了 70B 的 Chinchilla 模型分别作为 LM 和 RM) 。一种直觉是,偏好模型和生成模型需要具有类似的能力来理解提供给它们的文本。
接下来是最后一步:利用 RM 输出的奖励,用强化学习方式微调优化 LM。
Step 3. 用强化学习微调
长期以来出于工程和算法原因,人们认为用强化学习训练 LM 是不可能的。而目前多个组织找到的可行方案是使用策略梯度强化学习 (Policy Gradient RL) 算法、近端策略优化 (Proximal Policy Optimization,PPO) 微调初始 LM 的部分或全部参数。因为微调整个 10B~100B+ 参数的成本过高 (相关工作参考低秩适应 LoRA 和 DeepMind 的 Sparrow LM) 。PPO 算法已经存在了相对较长的时间,有大量关于其原理的指南,因而成为 RLHF 中的有利选择。
事实证明,RLHF 的许多核心 RL 进步一直在弄清楚如何将熟悉的 RL 算法应用到更新如此大的模型。
让我们首先将微调任务表述为 RL 问题。首先,该 策略 (policy) 是一个接受提示并返回一系列文本 (或文本的概率分布) 的 LM。这个策略的 行动空间 (action space) 是 LM 的词表对应的所有词元 (一般在 50k 数量级) ,观察空间 (observation space) 是可能的输入词元序列,也比较大 (词汇量 ^ 输入标记的数量) 。奖励函数 是偏好模型和策略转变约束 (Policy shift constraint) 的结合。
PPO 算法确定的奖励函数具体计算如下:将提示 x 输入初始 LM 和当前微调的 LM,分别得到了输出文本 y1, y2,将来自当前策略的文本传递给 RM 得到一个标量的奖励 rθ。将两个模型的生成文本进行比较计算差异的惩罚项,在来自 OpenAI、Anthropic 和 DeepMind 的多篇论文中设计为输出词分布序列之间的 Kullback–Leibler (KL) divergence 散度的缩放,即 KLr=rθ−λrKL 。这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值。此外,OpenAI 在 InstructGPT 上实验了在 PPO 添加新的预训练梯度,可以预见到奖励函数的公式会随着 RLHF 研究的进展而继续进化。
最后根据 PPO 算法,我们按当前批次数据的奖励指标进行优化 (来自 PPO 算法 on-policy 的特性) 。PPO 算法是一种信赖域优化 (Trust Region Optimization,TRO) 算法,它使用梯度约束确保更新步骤不会破坏学习过程的稳定性。DeepMind 对 Gopher 使用了类似的奖励设置,但是使用 A2C (synchronous advantage actor-critic) 算法来优化梯度。
作为一个可选项,RLHF 可以通过迭代 RM 和策略共同优化。随着策略模型更新,用户可以继续将输出和早期的输出进行合并排名。Anthropic 在他们的论文中讨论了 迭代在线 RLHF,其中策略的迭代包含在跨模型的 Elo 排名系统中。这样引入策略和 RM 演变的复杂动态,代表了一个复杂和开放的研究问题。
RLHF 的未来
尽管 RLHF 取得了一定的成果和关注,但依然存在局限。这些模型依然会毫无不确定性地输出有害或者不真实的文本。这种不完美也是 RLHF 的长期挑战和动力 —— 在人类的固有领域中运行意味着永远不会到达一个完美的标准。
收集人类偏好数据的质量和数量决定了 RLHF 系统性能的上限。RLHF 系统需要两种人类偏好数据:人工生成的文本和对模型输出的偏好标签。生成高质量回答需要雇佣兼职人员 (而不能依赖产品用户和众包) 。另一方面,训练 RM 需要的奖励标签规模大概是 50k 左右,所以并不那么昂贵 (当然远超了学术实验室的预算) 。目前相关的数据集只有一个基于通用 LM 的 RLHF 数据集 (来自 Anthropic 和几个较小的子任务数据集 (如来自 OpenAI 的摘要数据集) 。另一个挑战来自标注者的偏见。几个人类标注者可能有不同意见,导致了训练数据存在一些潜在差异。
除开数据方面的限制,一些有待开发的设计选项可以让 RLHF 取得长足进步。例如对 RL 优化器的改进方面,PPO 是一种较旧的算法,但目前没有什么结构性原因让其他算法可以在现有 RLHF 工作中更具有优势。另外,微调 LM 策略的一大成本是策略生成的文本都需要在 RM 上进行评估,通过离线 RL 优化策略可以节约这些大模型 RM 的预测成本。最近,出现了新的 RL 算法如隐式语言 Q 学习 (Implicit Language Q-Learning,ILQL) 也适用于当前 RL 的优化。在 RL 训练过程的其他核心权衡,例如探索和开发 (exploration-exploitation) 的平衡也有待尝试和记录。探索这些方向至少能加深我们对 RLHF 的理解,更进一步提升系统的表现。
注意以上信息,全部转载于Huggingface Blog:https://huggingface.co/blog/zh/rlhf Lambert, et al., "Illustrating Reinforcement Learning from Human Feedback (RLHF)", Hugging Face Blog, 2022.
(1) 应该关注那些metrics?
当对语言模型进行经典的监督微调时,损失(尤其是验证损失validation loss)可以作为训练进度的一个很好的指标。然而,在强化学习(RL)中,损失对模型性能的信息变得越来越少,而且它的值可能会随着实际性能的提高而波动。
为了解决这个问题,我们建议首先关注两个关键指标: Mean Reward 和 Objective KL Divergence
平均奖励Mean Reward: 主要目标是使模型在强化学习训练期间获得的奖励最大化。目标KL散度Objective KL Divergence: KL散度(Kullback-Leibler Divergence)衡量两个概率分布之间的不相似性。在强化学习训练的背景下,我们用它来量化当前模型和参考模型之间的差异。理想情况下,我们希望将KL散度保持在0到10之间,以确保模型生成的文本与参考模型生成的文本保持接近。其他指标
(2)我们为什么要使用参考模型? KL散度的目的是什么?
当训练RL模型时,仅针对奖励进行优化可能会导致意想不到的行为,其中模型以与良好语言生成不一致的方式利用环境。在RLHF的情况下,我们使用一个奖励模型来预测生成的文本是否被人类高度排名。
然而,针对奖励模型进行优化的强化学习模型可能会学习到产生高奖励但不代表好的语言的模式。这可能会导致极端情况,即模型生成带有过多感叹号或表情符号的文本,以最大化奖励。在一些最坏的情况下,该模型可能生成与自然语言完全无关的模式,但却获得了很高的回报,类似于对抗性攻击。
为了解决这个问题,我们基于当前模型和参考模型之间的KL分歧,在奖励函数中添加了一个惩罚。通过这样做,我们鼓励模型与参考模型生成的模型保持接近。
(3)对负KL散度的关注是什么?
如果你纯粹通过从模型分布中抽样来生成文本,一般情况下工作得很好。但是当你使用生成方法时,有一些注意事项,因为它并不总是纯粹的采样取决于可能导致kl散度为负的设置。本质上,当活动模型达到log_p_token_active < log_p_token_ref时,我们得到负KL-div。这可能在几种情况下发生:
- Top-k采样:模型可以平滑概率分布,导致Top-k个令牌的概率小于参考模型的概率,但仍然被选中;
- min_length:忽略EOS令牌,直到达到min_length。因此,该模型可以为EOS token分配非常高的日志概率,为所有其他令牌分配非常低的日志概率,直到达到min_length;
- 批处理生成:批处理中已完成的序列被填充,直到所有生成都完成。模型可以学习为padding tokens 分配非常低的概率,除非它们被适当地屏蔽或删除。
这只是几个例子。为什么负KL是一个问题? 总奖励R被计算为R = R - beta * KL,所以如果模型能够学习如何驱动KL-divergence为负,它就能有效地获得正奖励。在许多情况下,在生成过程中利用这种漏洞比实际学习奖励函数要容易得多。此外,KL可以变得任意小,因此实际奖励可能非常小。
(4)如何在训练中生成文本?
为了避免上述KL问题,我们建议使用以下设置
generation_kwargs = {
"min_length": -1, # don't ignore the EOS token (see above)
"top_k": 0.0, # no top-k sampling
"top_p": 1.0, # no nucleus sampling
"do_sample": True, # yes, we want to sample
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
"max_new_tokens": 32, # specify how many tokens you want to generate at most
}
使用这些设置,我们通常不会遇到任何问题。你也可以尝试其他设置,但如果你遇到负kl散度的问题,试着回到这些设置,看看它们是否持续存在。
基于人类反馈的强化学习,Reinforcement Learning from Human Feedback (RLHF)的更多相关文章
- 强化学习(Reinforcement Learning)中的Q-Learning、DQN,面试看这篇就够了!
1. 什么是强化学习 其他许多机器学习算法中学习器都是学得怎样做,而强化学习(Reinforcement Learning, RL)是在尝试的过程中学习到在特定的情境下选择哪种行动可以得到最大的回报. ...
- 强化学习 reinforcement learning: An Introduction 第一章, tic-and-toc 代码示例 (结构重建版,注释版)
强化学习入门最经典的数据估计就是那个大名鼎鼎的 reinforcement learning: An Introduction 了, 最近在看这本书,第一章中给出了一个例子用来说明什么是强化学习, ...
- 基于Keras的OpenAI-gym强化学习的车杆/FlappyBird游戏
强化学习 课程:Q-Learning强化学习(李宏毅).深度强化学习 强化学习是一种允许你创造能从环境中交互学习的AI Agent的机器学习算法,其通过试错来学习.如上图所示,大脑代表AI Agent ...
- 强化学习(Reinfment Learning) 简介
本文内容来自以下两个链接: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/ https: ...
- 增强学习Reinforcement Learning经典算法梳理3:TD方法
转自:http://blog.csdn.net/songrotek/article/details/51382759 博客地址:http://blog.csdn.net/songrotek/artic ...
- Deep Learning专栏--强化学习之MDP、Bellman方程(1)
本文主要介绍强化学习的一些基本概念:包括MDP.Bellman方程等, 并且讲述了如何从 MDP 过渡到 Reinforcement Learning. 1. 强化学习基本概念 这里还是放上David ...
- Reinforcement Learning,微信公众号:DRL学习
欢迎大家关注微信公众号:DRL学习,我们一起来学习强化学习和深度强化学习的算法及现状应用问题. 强化学习简单说就是学习如何最大化未来奖励的预期总和,以及agent学会在环境中做出的行动序列,其中随机状 ...
- 【强化学习】MOVE37-Introduction(导论)/马尔科夫链/马尔科夫决策过程
写在前面的话:从今日起,我会边跟着硅谷大牛Siraj的MOVE 37系列课程学习Reinforcement Learning(强化学习算法),边更新这个系列.课程包含视频和文字,课堂笔记会按视频为单位 ...
- Ⅰ Introduction to Reinforcement Learning
Dictum: To spark, often burst in hard stone. -- William Liebknecht 强化学习(Reinforcement Learning)是模仿人 ...
- 【转】强化学习(一)Deep Q-Network
原文地址:https://www.hhyz.me/2018/08/05/2018-08-05-RL/ 1. 前言 虽然将深度学习和增强学习结合的想法在几年前就有人尝试,但真正成功的开端就是DeepMi ...
随机推荐
- 012_DRC检查与处理
Check entire design:DRC检查整个原理图: Check Selection:DRC检查选择的部分电路: Use occurrences:选择所有事件进行检查: Use instan ...
- 几种常见Ruby on Rails内置方法介绍
Ruby on Rails是一个功能强大的WEB开发框架,在这里我们将会学到一些经常用到的Ruby on Rails内置方法,帮助大家熟练掌握其应用技巧. Ruby on Rails自动生成文档技巧大 ...
- golang url解析
package main import "fmt" import "net/url" import "strings" func main( ...
- Android开发环境配置 JDK及SDK
已经搭建过无数次开发环境,今天把搭建环境记录下,下次不用去搜索别人博客,有些博客都是复制粘贴,有些关键信息都缺失了. 1.首先第一步:下载JDK,配置JDK环境变量.JDK可以在Oracle官网下载, ...
- 几行命令用minikube快速搭建可测试的kubernetes单节点环境
几行命令用minikube快速搭建可测试的kubernetes单节点环境 需要docker环境,https://www.cnblogs.com/xiaofei12/p/17544579.html,网速 ...
- 基于webapi的websocket聊天室(四)
上一篇实现了多聊天室.这一片要继续改进的是实现收发文件,以及图片显示. 效果 问题 websocket本身就是二进制传输.文件刚好也是二进制存储的. 文件本身的传输问题不太,但是需要传输文件元数据,比 ...
- js毫秒转时分秒
const formatSeconds = (value) => { if (value === 0 || value < 1000) return '0秒'; var timestamp ...
- Sqlserver存储过程中使用try-catch和事务
BEGIN TRY BEGIN TRANSACTION --逻辑代码 COMMIT TRANSACTION --提交事务 END TRY BEGIN CATCH SELECT @Msg = ERROR ...
- Python 数据降级(重采样)
在数据处理中,经常有高频数据转成低频,秒级数据转成分钟.小时数据等.我们将讨论以下方法: 使用 Pandas 的 resample 方法: 示例:将天数据转化成月数据. 代码示例: import pa ...
- Flask学习记录:在w3cschool资料的基础上的个人摘录、实践与总结
学习与转载自w3cschool,在w3cschool资料的基础上的个人摘录.实践与总结,如有错误望留言. 一.Flask 概述 2021-08-25 14:01 更新 1.1 什么是Web Frame ...