当前位置: 首页 > news >正文

基于 Python 的自然语言处理系列(85):PPO 原理与实践

📌 本文介绍如何在 RLHF(Reinforcement Learning with Human Feedback)中使用 PPO(Proximal Policy Optimization)算法对语言模型进行强化学习微调。

🔗 官方文档:trl PPOTrainer

一、引言:PPO 在 RLHF 中的角色

        PPO(Proximal Policy Optimization)是一种常用的强化学习优化算法,它在 RLHF 的第三阶段发挥核心作用:通过人类偏好训练出的奖励模型对语言模型行为进行优化。我们将在本篇中详细介绍如何基于 Hugging Face 的 trl 库,结合 IMDb 数据集、情感分析奖励模型,完成完整的 PPO 训练流程。

二、环境依赖

pip install peft trl accelerate datasets transformers

三、配置 PPOConfig

from trl import PPOConfigppo_config = PPOConfig(model_name="lvwerra/gpt2-imdb",query_dataset="imdb",reward_model="sentiment-analysis:lvwerra/distilbert-imdb",learning_rate=1.41e-5,log_with=None,mini_batch_size=128,batch_size=128,target_kl=6.0,kl_penalty="kl",seed=0,
)

四、构建数据集与 Tokenizer

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.core import LengthSamplerdef build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):tokenizer = AutoTokenizer.from_pretrained(config.model_name, use_fast=True)tokenizer.pad_token = tokenizer.eos_tokends = load_dataset(query_dataset, split="train")ds = ds.rename_columns({"text": "review"})ds = ds.filter(lambda x: len(x["review"]) > 200)input_size = LengthSampler(input_min_text_length, input_max_text_length)def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sampleds = ds.map(tokenize)ds.set_format(type="torch")return dsdataset = build_dataset(ppo_config, ppo_config.query_dataset)

五、加载模型与参考模型(Ref Model)

from trl import AutoModelForCausalLMWithValueHeadmodel_cls = AutoModelForCausalLMWithValueHead
model = model_cls.from_pretrained(ppo_config.model_name)
ref_model = model_cls.from_pretrained(ppo_config.model_name)tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

六、构建 PPOTrainer 与奖励模型

from trl import PPOTrainer
from transformers import pipelinedef collator(data):return dict((key, [d[key] for d in data]) for key in data[0])ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

构建情感奖励模型

task, model_name = ppo_config.reward_model.split(":")
sentiment_pipe = pipeline(task, model=model_name, device=1 if torch.cuda.is_available() else "cpu", return_all_scores=True, function_to_apply="none", batch_size=16
)# 确保 tokenizer 设置 pad_token_id
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id

七、执行 PPO 训练循环

 
from tqdm.auto import tqdm
import torchgeneration_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 32,
}for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs)batch["response"] = tokenizer.batch_decode(response_tensors)batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)texts = [q + r for q, r in zip(batch["query"], batch["response"])]rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(texts)]ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]ref_rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(ref_texts)]batch["ref_rewards"] = ref_rewardsstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])

八、总结与展望

        在本篇文章中,我们实现了以下核心步骤:

阶段描述
数据构建利用 IMDb 构造简短语料用于语言生成
模型构建加载 GPT2 并构建 Value Head 以评估奖励
奖励模型使用 DistilBERT 进行情感打分作为奖励信号
PPO 训练利用 TRL 中的 PPOTrainer 实现语言强化优化

        PPO 是 RLHF 中至关重要的一环,在人类反馈基础上不断微调模型的输出质量,是当前 ChatGPT、Claude 等大模型背后的关键技术之一。

        📘 下一篇预告:《基于 Python 的自然语言处理系列(86):DPO(Direct Preference Optimization)原理与实战》
        相比传统 RLHF 流程,DPO 提供了一种更简洁、无需奖励模型与 PPO 的替代方案,敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 70.评论日记
  • Kubernetes in action-初相识
  • C++ 类及函数原型详解
  • 通过模仿学习实现机器人灵巧操作:综述(上)
  • 船舶参数(第一版)
  • 交叉熵损失函数:从信息量、熵、KL散度出发的推导与理解
  • 动态规划算法详解(C++)
  • 使用Tortoise-ORM和FastAPI构建评论系统
  • RDK X3新玩法:超沉浸下棋机器人开发日记
  • 通过VSCode远程连接到CentOS7/Ubuntu18等老系统
  • 单精度浮点运算/定点运算下 MATLAB (VS) VIVADO
  • 【大语言模型】大语言模型(LLMs)在工业缺陷检测领域的应用
  • AD相同网络的铜皮和导线连接不上
  • 泽众TestOne精准测试:助力软件开发质量新升级
  • VS Code搭建C/C++开发环境
  • 设置Rocky Linux盒盖不休眠的3个简单步骤
  • 第TR5周:Transformer实战:文本分类
  • MySQL 表结构及日志文件详解
  • 树莓派4B+Ubuntu24.04 电应普超声波传感器串口输出 保姆级教程
  • 国产AI大模型超深度横评:技术参数全解、商业落地全场景拆解
  • 《深化养老服务改革发展的大湾区探索》新书将于今年6月出版
  • 杨荫凯已任浙江省委常委、组织部部长
  • 全国首例!上市公司董监高未履行公开增持承诺,投资者起诉获赔
  • 韩国对华中厚板征收临时反倾销税
  • 为什么猛起身会头晕?你的身体在发出这个警报
  • 东部战区新闻发言人就美“劳伦斯”号导弹驱逐舰过航台湾海峡发表谈话