DPO: Direct Preference Optimization 直接偏好优化(学习笔记)
学习参考:链接1
一、为什么要提出DPO
在之前,我们已经了解到基于人类反馈的强化学习RLHF分为三个阶段:全监督微调(SFT)、奖励模型(RM)、强化学习(PPO)。但是RLHF面临缺陷:RLHF 是一个复杂且经常不稳定的过程,首先拟合反映人类偏好的奖励模型,然后使用强化学习微调大型无监督 LM,以最大化这种估计奖励,而不会偏离原始模型太远。为解决这一问题,提出一个直接偏好优化 (DPO) 的新算法:通过利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化,本质上是在人类偏好数据上解决一个分类问题。DPO是稳定的、性能和计算成本轻量级的,无需拟合奖励模型,在微调期间从 LM 中采样,或执行显着的超参数调整。通过实验表明:DPO 进行微调超过了 RLHF 效果,并提高了摘要和单轮对话的响应质量。
二、什么是DPO
DPO,一种基于人类偏好优化语言模型的新方法。与RLHF不同,DPO不依赖于明确的奖励建模或强化学习。它针对与RLHF相同的目标,但提供了一种更简单、更直接的培训方法。
DPO的工作原理:增加偏好样本的对数概率与减小非偏好样本响应的对数概率。它结合了动态加权机制,以避免仅使用概率比目标时遇到的模型退化问题。
DPO依赖于理论上的偏好模型,如Bradley-Terry模型,来测量奖励函数与经验偏好数据的对齐程度。与传统方法不同,传统方法使用偏好模型来训练奖励模型,然后基于该奖励模型训练策略,DPO直接根据策略定义偏好损失。给定一个关于模型响应的人类偏好数据集,DPO可以使用简单的二元交叉熵目标来优化策略,无需在训练过程中明确学习奖励函数或从策略中采样。具体推导见链接1
(1)原RLHF的优化目标:最大化奖励和最小化参考策略的KL散度
(2)DPO优化目标:利用了从奖励函数到最优策略的解析映射,允许直接使用人类偏好数据进行简化的优化过程
该目标增加了对偏好数据$y_w$的可能性,并减少了非偏好数据$y_l$的可能性。这些示例按照隐式奖励模型的评级加权,由$\beta$缩放.
DPO重参数化等效于具有隐式奖励函数:
参数模型$\pi_{\theta}$的优化等效于在此变量更改下的奖励模型优化。
(3)DPO在干什么?
为了从原理上理解 DPO,分析损失函数的梯度$L_{DPO} $。 相对于参数 θ 的梯度可以写为:
其中是由语言模型$\pi_{\theta}$和参考模型$\pi_{ref}$隐式定义的奖励函数。直观上,损失函数 $L_{DPO} $的梯度增加了偏好$y_w$ 的可能性,并降低了非偏好$y_l$的可能性。更重要的是,样例的权重是通过: 隐式奖励模型$\hat{r}_{\theta}$对非偏好的评分高多少来衡量的,即$\hat{r}_{\theta}(x,y_l)-\hat{r}_{\theta}(x,y_w)$,按 β 进行缩放,即隐式奖励模型认为策略模型错误的程度。 我们的实验表明了这种加权的重要性,因为没有加权系数的这种方法的简单版本可能会导致语言模型退化。
(4)DPO outline
步骤1)是在构造数据集,通过对同一问题的两种回复的倾向性:chosen or rejected,反映人类偏好。
步骤2)在于优化,具体过程大概是,对于同一个question prompt,模型在两种模型:language/policy model 和 reference model下分别生成,对应chosen 和 rejected label真值标签的生成概率,因此可以获得四种概率值:policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, 用于DPO loss计算。
1、DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。
数据集由三部分组成:
prompt
chosen
rejected
可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf
dataset
2、 预期模型格式
与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。
3、使用 DPOTrainer 源码
有关详细示例,请查看 Examples/scripts/dpo.py 脚本。 在较高级别上,我们需要使用我们希望训练的模型、参考 ref_model 来初始化 DPOTrainer,我们将使用它来计算首选和拒绝响应的隐式奖励,beta 指隐式奖励的超参数, 数据集包含上面列出的 3 个条目。 请注意,模型和 ref_model 需要具有相同的架构(即仅解码器或编码器-解码器)。
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
之后就可以调用:
dpo_trainer.train()
请注意,β 是 DPO 损失的温度参数,通常在 0.1 到 0.5 范围内。 当beta -> 0 ,意味着忽略参考模型。
4、损失函数
给定偏好数据,我们可以根据 Bradley-Terry 模型拟合二元分类器,事实上,DPO 作者通过 Logsigmoid 提出标准化似然的 sigmoid 损失来拟合逻辑回归。
def dpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities. Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
if reference_free:
ref_logratios = 0
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
) chosen_rewards = (
self.beta
* (
policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
).detach()
)
rejected_rewards = (
self.beta
* (
policy_rejected_logps.to(self.accelerator.device)
- reference_rejected_logps.to(self.accelerator.device)
).detach()
) return losses, chosen_rewards, rejected_rewards
其他改进的损失函数:
RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。
5、指标:在训练和评估时,记录以下奖励指标:
rewards/chosen
: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by betarewards/rejected
: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by betarewards/accuracies
: mean of how often the chosen rewards are > than the corresponding rejected rewardsrewards/margins
: the mean difference between the chosen and corresponding rejected rewards
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {} (
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch) # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() return losses.mean(), metrics
DPO: Direct Preference Optimization 直接偏好优化(学习笔记)的更多相关文章
- KVM性能优化学习笔记
本学习笔记系列都是采用CentOS6.x操作系统,KVM虚拟机的管理也是采用virsh方式,网上的很多的文章都基于ubuntu高版本内核下,KVM的一些新的特性支持更好,本文只是记录了CentOS6. ...
- 深挖计算机基础:Linux性能优化学习笔记
参考极客时间专栏<Linux性能优化实战>学习笔记 一.CPU性能:13讲 Linux性能优化实战学习笔记:第二讲 Linux性能优化实战学习笔记:第三讲 Linux性能优化实战学习笔记: ...
- Pandas 性能优化 学习笔记
摘要 本文介绍了使用 Pandas 进行数据挖掘时常用的加速技巧. 实验环境 import numpy as np import pandas as pd print(np.__version__) ...
- mysql性能优化学习笔记(2)如何发现有问题的sql
一.使用mysql慢查询日志对有效率问题的sql进行监控 1)开启慢查询 show variables like ‘slow_query_log’;//查看是否开启慢查询日志 ...
- HIVE优化学习笔记
概述 之前写过关于hive的已经有两篇随笔了,但是作者依然还是一枚小白,现在把那些杂七杂八的总结一下,供以后查阅和总结.今天的文章介绍一下hive的优化.hive是好多公司都在使用的东西,也有好多大公 ...
- 燕十八MySQL优化学习笔记
观察 show status; 里面的这三个参数;Queries Threads_connected Threads_running判断周期性变化 -------------------------- ...
- mysql性能优化学习笔记-参数介绍及优化建议
MySQL服务器参数介绍 mysql参数介绍(客户端中执行),尽量只修改session级别的参数. 全局参数(新连接的session才会生效,原有已经连接的session不生效) set global ...
- mysql性能优化学习笔记
mysql性能优化 硬件对数据库的影响 CPU资源和可用内存大小 服务器硬件对mysql性能的影响 我们的应用是CPU密集型? 我们的应用的并发量如何? 数量比频率更好 64位使用32位的服务器版本 ...
- mysql优化学习笔记
优化sql的一般步骤 通过show status了解各种sql的执行频率 定位执行效率低的sql语句 通过explain分析效率低的sql 通过show profile分析sql 通过trace分析优 ...
- js性能优化--学习笔记
<高性能网站建设进阶指南>: 1.使用局部变量,避免深入作用域查找,局部变量是读写速度最快的:把函数中使用次数超过一次的对象属性和数组存储为局部变量是一个好方法:比如for循环中的.len ...
随机推荐
- 第8讲 browse命令的使用技巧
第8讲 browse命令的使用技巧 1.浏览所有parts,使用技巧 选中工程文件*.dsn/Edit/Browse/Parts.列出工程中用到的所有元件,方便在画完原理图后,查看哪些元件没有编号或数 ...
- 02、Linux 排查
Linux 分析排查 1.敏感文件信息 1.1.tmp 目录 /tmp:临时目录文件,每个用户都可以对它进行读写操作.因此一个普通用户可以对 /tmp 目录执行读写操作(ls -alt) 筛查 /tm ...
- LVS负载均衡(5)-- LVS持久连接
持久连接: 持久连接用于实现无论使用任何调度算法,在一段时间内(默认300s ),能够实现将来自同一个地址的请求始终发往同一个RS. 语法格式: ipvsadm -A|E -t|u|f service ...
- Ubuntu中安装OpenSSL
一.前期准备 1.1 压缩包下载 在安装openssl之前,我们需要下载对应的压缩包 https://www.openssl.org/source/openssl-3.0.1.tar.gz 此压缩包可 ...
- Javascript返回顶部和砸金蛋,跑马灯等游戏代码实现
1. 我们经常写页面的时候会遇到页面很长需要做返回顶部的操作:$("id /class").animate({scrollTop:$('.class').offset().top} ...
- grads 同时读取多个ctl文件方法
1.不同的文件进行不同的设置:'set dfile 2' 2.读取不同文件的变量:qv.2 实例如下:'reinit''open e:\tskt.CTL''open e:\uwnd.CTL''open ...
- PageOffice 6 给SaveFilePage指向的保存地址传参
PageOffice给保存方法传递参数的方式有两种: 通过设置保存地址的url中的?传递参数.例如: poCtrl.setSaveFilePage("/save?p1=1") 通过 ...
- kubernetes 之dashboard
部署 kubectl apply -f https://raw.githubusercontent.com/kubernetes/dashboard/v2.5.0/aio/deploy/recomme ...
- Django与前端框架协作开发实战:高效构建现代Web应用
title: Django与前端框架协作开发实战:高效构建现代Web应用 date: 2024/5/22 20:07:47 updated: 2024/5/22 20:07:47 categories ...
- HTML——table表格标签
一.table表格的完整写法 <!DOCTYPE html> <html> <head> <meta charset="utf-8"> ...