基于 Python 的自然语言处理系列(86):DPO(Direct Preference Optimization)原理与实战
✨ 本文是 RLHF 系列的延续,介绍 Hugging Face
trl
库中的 DPOTrainer 的使用方法与原理,帮助你理解如何使用直接偏好优化方法(Direct Preference Optimization, DPO)进行大语言模型偏好微调。
1. 什么是 DPO?
传统 RLHF 流程包括三个阶段:有监督微调(SFT)、奖励模型训练(RM)与强化学习(PPO)。而 DPO(Direct Preference Optimization)提出了一种 无需显式奖励模型与价值函数 的替代方案:
-
假设模型本身隐式表示了奖励函数;
-
通过比较“优选(chosen)”与“被拒(rejected)”的响应,在 KL 约束下最大化偏好概率差异;
-
更易于训练和部署,显著简化 RLHF 流程。
论文链接:Direct Preference Optimization
官方实现:trl.dpo_trainer
2. 准备工作与依赖安装
pip install transformers datasets trl peft
设置 CUDA 设备:
import os
import torchos.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3. 构造符合格式的数据集
DPO 需要每条样本包含 3 个字段:
-
prompt
: 用户输入; -
chosen
: 优选回答; -
rejected
: 被拒回答。
可手动构造或使用已有偏好数据集:
dpo_dataset_dict = {"prompt": ["hello","how are you","What is your name?","Which is the best programming language?",],"chosen": ["hi nice to meet you","I am fine","My name is Mary","Python",],"rejected": ["leave me alone","I am not fine","I don't have a name","C++",],
}
4. 加载预训练模型和 Tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name_or_path = "gpt2"model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)if tokenizer.pad_token is None:tokenizer.pad_token = tokenizer.eos_token
5. 加载标准偏好数据集(Anthropic HH)
from datasets import load_dataset
from typing import Dictdef extract_anthropic_prompt(text):search = "\n\nAssistant:"idx = text.rfind(search)assert idx != -1, f"Missing assistant tag"return text[:idx + len(search)]def get_hh(split="train", sanity_check=True) -> torch.utils.data.Dataset:ds = load_dataset("Anthropic/hh-rlhf", split=split)if sanity_check:ds = ds.select(range(min(len(ds), 1000)))def convert(sample):prompt = extract_anthropic_prompt(sample["chosen"])return {"prompt": prompt,"chosen": sample["chosen"][len(prompt):],"rejected": sample["rejected"][len(prompt):],}return ds.map(convert)
6. 配置训练参数
from transformers import TrainingArgumentstraining_args = TrainingArguments(per_device_train_batch_size=8,max_steps=1000,gradient_accumulation_steps=1,learning_rate=1e-3,evaluation_strategy="steps",eval_steps=500,logging_steps=5,logging_first_step=True,warmup_steps=150,output_dir="./dpo_output",bf16=True,optim="rmsprop",remove_unused_columns=False,
)
7. 初始化 DPOTrainer 并训练模型
from trl import DPOTrainerdpo_trainer = DPOTrainer(model=model,ref_model=model_ref,args=training_args,beta=0.1, # KL 控制项系数train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,max_length=512,max_prompt_length=128,max_target_length=128,generate_during_eval=True,
)
开始训练:
dpo_trainer.train()
8. DPO 相较于 PPO 的优势总结
维度 | PPO | DPO |
---|---|---|
是否需 Value Head | ✅ 需要 | ❌ 不需要 |
奖励函数 | 外部 RM | 隐式建模 |
算法复杂性 | 较高 | 简洁 |
模型要求 | AutoModelForCausalLMWithValueHead | AutoModelForCausalLM |
收敛速度 | 慢 | 快 |
9. 总结与展望
通过本文,我们完成了 DPO 从原理到实战的全过程实现,涵盖了:
-
✅ 数据格式构造
-
✅ 偏好数据加载与转换
-
✅ 模型加载与参考模型初始化
-
✅
DPOTrainer
调用与训练过程 -
✅ 与 PPO 的结构性比较
📌 DPO 在 RLHF 中极具实用价值,尤其适用于资源受限或对部署复杂性要求较低的场景。
🔜 下一篇预告:《基于 Python 的自然语言处理系列(87):RRHF》
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!