当前位置: 首页 > news >正文

基于 Python 的自然语言处理系列(86):DPO(Direct Preference Optimization)原理与实战

✨ 本文是 RLHF 系列的延续,介绍 Hugging Face trl 库中的 DPOTrainer 的使用方法与原理,帮助你理解如何使用直接偏好优化方法(Direct Preference Optimization, DPO)进行大语言模型偏好微调。

1. 什么是 DPO?

        传统 RLHF 流程包括三个阶段:有监督微调(SFT)、奖励模型训练(RM)与强化学习(PPO)。而 DPO(Direct Preference Optimization)提出了一种 无需显式奖励模型与价值函数 的替代方案:

  • 假设模型本身隐式表示了奖励函数;

  • 通过比较“优选(chosen)”与“被拒(rejected)”的响应,在 KL 约束下最大化偏好概率差异;

  • 更易于训练和部署,显著简化 RLHF 流程。

论文链接:Direct Preference Optimization
官方实现:trl.dpo_trainer

2. 准备工作与依赖安装

pip install transformers datasets trl peft

设置 CUDA 设备:

import os
import torchos.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. 构造符合格式的数据集

        DPO 需要每条样本包含 3 个字段:

  • prompt: 用户输入;

  • chosen: 优选回答;

  • rejected: 被拒回答。

        可手动构造或使用已有偏好数据集:

dpo_dataset_dict = {"prompt": ["hello","how are you","What is your name?","Which is the best programming language?",],"chosen": ["hi nice to meet you","I am fine","My name is Mary","Python",],"rejected": ["leave me alone","I am not fine","I don't have a name","C++",],
}

4. 加载预训练模型和 Tokenizer

from transformers import AutoModelForCausalLM, AutoTokenizermodel_name_or_path = "gpt2"model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)if tokenizer.pad_token is None:tokenizer.pad_token = tokenizer.eos_token

5. 加载标准偏好数据集(Anthropic HH)

from datasets import load_dataset
from typing import Dictdef extract_anthropic_prompt(text):search = "\n\nAssistant:"idx = text.rfind(search)assert idx != -1, f"Missing assistant tag"return text[:idx + len(search)]def get_hh(split="train", sanity_check=True) -> torch.utils.data.Dataset:ds = load_dataset("Anthropic/hh-rlhf", split=split)if sanity_check:ds = ds.select(range(min(len(ds), 1000)))def convert(sample):prompt = extract_anthropic_prompt(sample["chosen"])return {"prompt": prompt,"chosen": sample["chosen"][len(prompt):],"rejected": sample["rejected"][len(prompt):],}return ds.map(convert)

6. 配置训练参数

from transformers import TrainingArgumentstraining_args = TrainingArguments(per_device_train_batch_size=8,max_steps=1000,gradient_accumulation_steps=1,learning_rate=1e-3,evaluation_strategy="steps",eval_steps=500,logging_steps=5,logging_first_step=True,warmup_steps=150,output_dir="./dpo_output",bf16=True,optim="rmsprop",remove_unused_columns=False,
)

7. 初始化 DPOTrainer 并训练模型

from trl import DPOTrainerdpo_trainer = DPOTrainer(model=model,ref_model=model_ref,args=training_args,beta=0.1,  # KL 控制项系数train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,max_length=512,max_prompt_length=128,max_target_length=128,generate_during_eval=True,
)

开始训练:

dpo_trainer.train()

8. DPO 相较于 PPO 的优势总结

维度PPODPO
是否需 Value Head✅ 需要❌ 不需要
奖励函数外部 RM隐式建模
算法复杂性较高简洁
模型要求AutoModelForCausalLMWithValueHeadAutoModelForCausalLM
收敛速度

9. 总结与展望

        通过本文,我们完成了 DPO 从原理到实战的全过程实现,涵盖了:

  • ✅ 数据格式构造

  • ✅ 偏好数据加载与转换

  • ✅ 模型加载与参考模型初始化

  • DPOTrainer 调用与训练过程

  • ✅ 与 PPO 的结构性比较

📌 DPO 在 RLHF 中极具实用价值,尤其适用于资源受限或对部署复杂性要求较低的场景。

🔜 下一篇预告:《基于 Python 的自然语言处理系列(87):RRHF》

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 【信息系统项目管理师】高分论文:论质量管理和进度管理(智慧旅游平台建设项目)
  • HBase协处理器深度解析:原理、实现与最佳实践
  • 基于FFmpeg命令行的实时图像处理与RTSP推流解决方案
  • 使用java代码注册onloyoffice账号 || 注册onloyoffice账号
  • vue中 vue.config.js反向代理
  • 计算机网络 | 应用层(3)-- 因特网中的电子邮件
  • 使用银行卡二要素API让支付更加安心
  • 北斗导航 | Transformer增强BiLSTM网络的GNSS伪距观测量误差探测
  • B. And It‘s Non-Zero
  • 提示词的神奇魔力——如何通过它改变AI的输出
  • 免费送源码:Java+ssm+HTML 三分糖——甜品店网站设计与实现 计算机毕业设计原创定制
  • springboot + mybatis 需要写 .xml吗
  • Java—— 五道算法水题
  • 力扣热题100题解(c++)—链表
  • 架构师备考-设计模式23种及其记忆特点
  • 【虚幻C++笔记】碰撞检测
  • 指标监控:Prometheus 结合 Grafana,监控redis、mysql、springboot程序等等
  • 一文详解Adobe Photoshop 2025安装教程
  • Springboot集成SSE实现消息推送+RabbitMQ解决集群环境下SSE通道跨节点事件推送问题
  • 【BBDM】main.py -- notes
  • “冲刺万亿城市”首季表现如何?温州领跑,大连GDP超徐州
  • 四川公布一起影视盗版案例:1个网站2人团伙盗售30多万部
  • 石磊当选河北秦皇岛市市长
  • 西安市优化营商环境投诉举报监督平台上线,鼓励实名检举控告
  • “仅退款”将成历史?电商平台集中调整售后规则
  • 广电总局加快布局超高清视听产业链,多项成果亮相