当前位置: 首页 > news >正文

强化学习:高级策略梯度理论与优化方法

如果您想学习强化学习,我推荐David Sliver的讲座😊:RL Course by David Silver - Lecture 1: Introduction to Reinforcement Learning - YouTube

在本文开始前,如果您还没读过我的前一篇文章,由此进:

强化学习:基础理论与高级DQN算法及策略梯度基础-CSDN博客

自然策略梯度(NPG)与信息几何

1.策略空间的黎曼流式结构

  • 策略分布族:将策略参数空间视为统计流形 \mathcal{M}={\pi_\theta|\theta \in \Theta}

  • Fisher信息矩阵(黎曼度量张量):

F(\theta)=\mathbb{E}_{s\sim p^\pi,a\sim \pi_\theta}[\nabla_\theta\log \pi_\theta(a|s)\nabla_\theta\log\pi_\theta(a|s)^\top]

  • KL散度的局部近似(泰勒展开到二阶):

\text{KL}(\pi_\theta||\pi_{\theta+\Delta_\theta})\approx \frac{1}{2}\Delta\theta^\top F(\theta)\Delta\theta

2.自然梯度定义

  • 传统梯度方向在欧氏空间,自然梯度在黎曼空间:

\tilde{\nabla}_\theta J=F(\theta)^{-1}\nabla_\theta J

  • 最优更新方向证明:

求解带KL约束的优化问题:

\underset{\Delta\theta}{\max}J(\theta+\Delta\theta)\quad s.t. \quad \text{KL}(\pi_\theta||\pi_{\theta+\Delta\theta}) \leq \epsilon

通过拉格朗日乘子法得到自然梯度方向

3.自然策略梯度更新规则

\theta_{k+1}=\theta_k+\alpha F(\theta_k)^{-1}\nabla_\theta J(\theta_k)

实际计算技巧:

  • 使用共轭梯度法避免显示求逆

  • 增广矩阵法处理秩亏问题

兼容函数逼近定理

1.严格条件陈述

当价值函数逼近器Q_w(s,a)满足:

  1. 兼容性: \nabla_wQ_w(s,a)=\nabla_\theta\log\pi_\theta(a|s)

  2. 最小化均方误差:

w^*=\arg\underset{w}{\max}\mathbb{E}[(Q_w(s,a)-Q^\pi(s,a))^2]

则策略梯度估计无偏:

\nabla_\theta J(\theta)=\mathbb{E}[\nabla_\theta\log\pi_\theta(a|s)Q_w(s,a)]

2.证明概要

  • 条件1保证价值函数梯度与策略梯度在同一方向

  • 条件2保证 Q_wQ^\pi 在兼容子空间上的正交投影

  • 联合推导可得:\mathbb{E}[\nabla_\theta\log\pi_\theta(Q_w-Q^\pi)]=0

信任区域策略优化(TRPO)

1.核心目标与约束

优化问题:

\underset{\theta}{max}

s.t.\mathbb{E}_s[\text{KL}(\pi_{\theta_{old}}||\pi_\theta)(s)] \leq \delta

2.目标函数的局部近似

  • 优势函数近似(一阶泰勒展开):

L(\theta)\approx L(\theta_{old})+g^\top(\theta-\theta_{old})

其中 g=\nabla_\theta L|_{\theta=\theta_{old}}

  • KL散度的二阶近似:

\text{KL}(\theta_{old}||\theta)\approx \frac{1}{2}\Delta\theta^\top F(\theta_{old})\Delta\theta

F 是Fisher信息矩阵

3.解析解推导

通过拉格朗日乘子法得到最优更新方向:

\theta^*=\theta

自然梯度方向 F^{-1}g 在策略流形上是最速上升方向

4.实现中的共轭梯度法

求解 F^{-1}g 的步骤

  1. 计算Fisher-vector product:Fv=\mathbb{E}[\nabla(\nabla\log\pi)^\top v]

  2. 使用共轭梯度法迭代求解 Fx=g

  3. 通过回溯线搜索确保KL约束

近端策略优化(PPO)

1.剪切目标函数

L^{\text{CLIP}}(\theta)=\mathbb{E}_t[\min(r_t(\theta)A_t,clip(r_t(\theta),1-\epsilon,1+\epsilon)A_t)]

其中 r_t(\theta)=\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}

剪切区域分析:

  • A_t > 0,限制最大更新幅度为 (1 + \epsilon)A_t

  • A_t < 0,限制最小更新幅度为 (1-\epsilon)A_t

2.自适应KL惩罚项

目标函数:

L^{\text{KLPEN}}(\theta)=\mathbb{E}_t[r_t(\theta)A_t-\beta\text{KL}(\pi\theta_{old}||\pi_\theta)]

  • \beta 自适应规则:

\beta_{k+1}=\begin{cases}\beta_k/1.5 \quad if\:\text{KL}<\delta_{\text{low}}\\\beta_k \times2 \quad if\: \text{KL}>\delta_{\text{high}} \\\beta_k \quad \quad \:\:\:otherwise\end{cases}

典型设置:\delta_{\text{low}}=0.01, \delta_{\text{high}}=0.1

3.重要性采样方差控制

原始重要性权重方差:

\text{Var}(r_t)=\mathbb{E}[(\frac{\pi_\theta}{\pi_{old}}-1)^2]

剪切后的方差上界:

\text{Var}(r_t^{\text{clip}}\leq \epsilon^2\mathbb{E}[A_t^2])

直接偏好优化(DPO)

1.从奖励模型到策略的隐式转换

基于Bradley-Terry模型:

p^*(y_1 \succ y_2|x)=\frac{\exp(\beta\mathcal{R}(x,y_1))}{\exp(\beta\mathcal{R}(x,y_1))+\exp(\beta\mathcal{R}(x,y_2))}

关键替换:用策略表示奖励函数

\mathcal{R}(x,y)=\beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}+\beta\log Z(x)

2.目标函数推导

消去奖励函数后得到:

\mathcal{L}_\text{DPO}=-\mathbb{E}_(x,y_w,y_l)[\log \sigma(\beta\log\frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)}-\beta\log\frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)})]

其中 \sigma 是sigmoid函数

3.隐式KL约束分析

DPO等价于带动态约束的优化:

\underset{\theta}{\max}\mathbb{E}[\log\sigma(\beta\Delta\log\pi)]\quad s.t. \quad\text{KL}(\pi_\theta||\pi_\text{ref})\leq C

4.梯度分析

梯度计算公式:

\nabla_\theta \mathcal{L}_{\text{DPO}}=-\beta \mathbb{E}[\sigma(\hat{r_l}-\hat{r_w})(\nabla_\delta\log\pi_\theta(y_w|x))-\nabla_\delta\log\pi_\theta(y_l|x))]

其中  \hat{r_i}=\log\frac{\pi_\theta(y_i|x)}{\pi_{\text{ref}}(y_i|x)}

如果您对RL和测试时间扩展感兴趣,我自推这篇文章:从理论到实践:带你快速学习基于PRM的三种搜索方法-CSDN博客

相关文章:

  • leetcode110 平衡二叉树
  • 在QML中获取当前时间、IP和位置(基于网络请求)
  • Simple-BEV论文解析
  • module.noParse(跳过指定文件的依赖解析)
  • [贪心_8] 跳跃游戏 | 单调递增的数字 | 坏了的计算器
  • GitOps进化:深入探讨 Argo CD 及其对持续部署的影响
  • 青少年编程与数学 02-018 C++数据结构与算法 12课题、递归
  • 多模态大语言模型arxiv论文略读(四十二)
  • Dify框架面试内容整理-Dify如何实现模型调用与管理?
  • 【OSG学习笔记】Day 10: 字体与文字渲染(osgText)
  • 两台没有网络的电脑如何通过网线共享传输文件
  • Compose笔记(十八)--rememberLazyListState
  • 【第11节 嵌入式软件的组成】
  • 从后端研发角度出发,使用k8s部署业务系统
  • ARP协议【复习篇】
  • Tortoise-ORM级联查询与预加载性能优化
  • Nacos简介—3.Nacos的配置简介
  • 如何修改npm的全局安装路径?
  • 冲刺一区!挑战7天一篇文献计量学SCI DAY1-7
  • 机器之眼megauging(工业机器视觉软件)是否开源?
  • 从“高阶智驾”到“辅助驾驶”,上海车展上的“智驾”宣发变调
  • 今年地质灾害防治形势严峻,哪些风险区被自然资源部点名?
  • 上海4-6月文博美展、剧目演出不断,将开设直播推出文旅优惠套餐
  • 习近平在气候和公正转型领导人峰会上的致辞(全文)
  • 中国和阿塞拜疆签署互免签证协定
  • 国防部发布、中国军号及多家央媒官博发祝福海报:人民海军76岁生日快乐