当前位置: 首页 > news >正文

强化学习框架:OpenRLHF源码解读,模型处理

本文主要介绍 强化学习框架:OpenRLHF源码解读,模型处理

models框架设计

了解一下 OpenRLHF的模型框架设计范式:

From:https://arxiv.org/pdf/2405.11143

可以知道一个大概的流程:输入Pormpt通过Actor model输出回复 Response,而后将两部分进行拼接再去由其他模型进行处理

1、actor.py

https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/actor.py

这部分主要为加载所需要的模型

class Actor(nn.Module):def __init__(...):if isinstance(pretrain_or_model, str):...self.model = model_class.from_pretrained(pretrain_or_model,trust_remote_code=True,attn_implementation=attn_implementation,quantization_config=nf4_config,torch_dtype=torch.bfloat16 if bf16 else "auto",device_map=device_map,)if lora_rank > 0:self.model.enable_input_require_grads()lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,r=lora_rank,lora_alpha=lora_alpha,target_modules=target_modules,lora_dropout=lora_dropout,bias="none",)self.model = get_peft_model(self.model, lora_config)...else:self.model = pretrain_or_model@torch.no_grad()def generate(self, input_ids: torch.Tensor, **kwargs):...sequences = self.model.generate(**generate_args)eos_token_id = generate_args["eos_token_id"]pad_token_id = generate_args["pad_token_id"]return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)def forward(...):...output["logits"] = output["logits"].to(torch.float32) # 得到每一个token概率...log_probs = log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature)...action_log_probs = log_probs[:, -num_actions:]

这个actor比较简单,首先从huggingface加载需要的模型,并且对模型进行部分设置如:量化/lora微调。或者直接加载自己预训练好的模型。
1、generate:模块则是根据输入的内容(比如说被 tokenizer处理好的文本)input_ids通过模型输出新的内容(根据 **kwargs获取生成文本参数设置比如说:top_k等)
2、forward根据输入的 token 序列(sequences),计算模型在生成最后若干个 token(即 “动作”)时的对数概率(log probs),之所以要这么处理是因为,在强化学习模型中(PPO、DPO等)一般而言模型的输出是一个序列,但优化目标不是“能不能生成这个序列”,而是:这个序列中,哪些 token 是“好”的?模型对这些 token 的概率应该更高!比如说在 DPO中:

L ( θ ) = E [ m i n ( r ( θ ) ∗ A , c l i p ( r ( θ ) , 1 − ε , 1 + ε ) ∗ A ) ] L(θ) = E[ min(r(θ) * A, clip(r(θ), 1-ε, 1+ε) * A) ] L(θ)=E[min(r(θ)A,clip(r(θ),1ε,1+ε)A)]

里面的

r ( θ ) = π θ ( a ∣ s ) / π o l d ( a ∣ s ) r(\theta)=\pi_{\theta}(a|s)/\pi_{old}(a|s) r(θ)=πθ(as)/πold(as)

就是概率比值,上面代码中:

log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature)

计算的就是: l o g ( π θ ( a ∣ s ) ) log(\pi_{\theta}(a|s)) log(πθ(as)),在具体代码中:

def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:if temperature != 1.0:logits.div_(temperature)if logits.dtype in [torch.float32, torch.float64]:batch_dim = logits.shape[:-1]last_dim = logits.shape[-1]try:from flash_attn.ops.triton.cross_entropy import cross_entropy_lossoutput = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))log_probs_labels = -output[0].view(*batch_dim)except ImportError:logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))logsumexp_values = logsumexp_values.view(*batch_dim)log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)else:log_probs_labels = []for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumptionrow_log_probs = F.log_softmax(row_logits, dim=-1)row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)log_probs_labels.append(row_log_probs_labels)log_probs_labels = torch.stack(log_probs_labels)return log_probs_labels

补充-1
在使用 AutoModelForCausalLM.from_pretrained使用得到 model之后,其支持输入参数为:

outputs = model(input_ids=None,            # 输入的token(batch_size, seq_length)attention_mask=None,       # 指示哪些 token 是有效的(非 padding),形状同 input_idsposition_ids=None,         # 位置编码past_key_values=None,inputs_embeds=None,use_cache=None,            # 是否使用k-v cachelabels=None,               # 输入标签就直接计算lossoutput_attentions=None,output_hidden_states=None,return_dict=None,
)

补充-2
在LLM训练过程中遇到过短的语句为了节约显存(如果都将内容补充到相同长度,那么就会有较多的padding造成浪费),因此可以将几个短的拼接起来,但是为了区分那些是一个句子那些不是的,在 OpenRLHF中通过参数:self.packing_samples。如果没有 packing那么直接根据 attention_mask将位置编码在处理一下

if not self.packing_samples:position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)
else:# convert attention_mask to position_idsif ring_attn_group is not None:labels = sequencessequences, attention_mask, position_ids = convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_attn_group)else:position_ids = reset_position_ids(attention_mask)# explicitly ignore attention_mask for packing_samplesattention_mask = None

其中 reset_position_ids做的就是重新做位置编码重新处理

2、model.py

https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/model.py

主要功能返回所需要的模型,主要返回2个模型:1、CriticModel;2、RewardModel 回顾一下这几类模型的作用:无论是在GRPO还是DPO中都会输出token然后需要去对token进行评分,起评分作用的就是 reward model 对应上面图中 reward model,除此之外都会计算 优势函数 Q ( s , a ) − V ( s ) Q(s,a)-V(s) Q(s,a)V(s))来评估策略的好坏优势函数里面计算就是通过 critic model来对某一个策略进行评估对应上面图像中的:value model

def _get_reward_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):class RewardModel(base_pretrained_model):def __init__(...):...# 加载模型setattr(self, self.base_model_prefix, base_llm_model(config))self.value_head_prefix = value_head_prefixsetattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False) # 输出评分...def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, return_output=False, ring_attn_group=None,pad_sequence=False, packed_seq_lens=None,):...# 1、处理packingoutputs = getattr(self, self.base_model_prefix)(input_ids, attention_mask=attention_mask, position_ids=position_ids)last_hidden_states = outputs["last_hidden_state"]values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)...# 1、处理packingelse:# 输出最后一个有效token的评分代替整个句子评分eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)reward = values.gather(dim=1, index=eos_indices).squeeze(1)if not self.training and self.normalize_reward:reward = (reward - self.mean) / self.stdreturn (reward, outputs) if return_output else rewardreturn RewardModeldef _get_critic_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):class CriticModel(base_pretrained_model):def __init__(...):...def forward(...):...# 1、处理packingoutputs = getattr(self, self.base_model_prefix)(input_ids, attention_mask=attention_mask, position_ids=position_ids)last_hidden_states = outputs["last_hidden_state"]values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)...if num_actions is None:assert return_outputreturn outputsif not self.packing_samples:action_values = values[:, -num_actions:]else:assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens)action_values = []offset = 0for num_action, seq_len in zip(num_actions, packed_seq_lens):start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1action_values.append(values[:, start:end])offset += seq_lenaction_values = torch.cat(action_values, dim=1)if return_output:return (action_values, outputs)else:return action_valuesreturn CriticModel

1、reward model: 传入一个 base_pretrained_model(比如 PreTrainedModel)、一个 base_llm_model(比如 AutoModel)以及一些控制参数。函数内部返回一个定制化的奖励模型类 RewardModel,它可以在给定输入句子时,输出一个数值(reward 分数),反映输出文本的质量。在forward计算中,直接将输入model使用的几个参数(见上面的补充有具体解释)计算最后取最后一个状态的值,并且将这个值取计算评分。也就是说 reward model:首先计算下一个预测的token而后对这些token进行打分
2、critic model:具体输入参数和 reward model相同。参考之前介绍,上面代码中直接返回action_values = values[:, -num_actions:]num_actions存在条件下)这样就会得到不同的Q(s, a1), Q(s, a2), …

总结上面两组模型,在 LLM 的强化学习场景下,Reward Model 和 Critic Model 都从 last_hidden_state 得到 token-level 表达,再用 Linear 层输出每个 token 的 score。

  • Reward Model 最后提取的是 EOS token 的 score,表示整句话的奖励。
  • Critic Model 会进一步提取最后 num_actions 个 token 的 value,这些 token 是 Actor 生成的动作,对应到 PPO 中的:𝐴(𝑠,𝑎)=𝑄(𝑠,𝑎)−𝑉(𝑠)。

理解上面内容,回顾最上面的框架设计,用下面例子进行解释。
Prompt:"The capital of France is"
Actor model:"Paris is beautiful"。那么合并得到:input_ids = ["The", "capital", "of", "France", "is", " Paris", "is", "beautiful"]
Reward model:对上面每个单词进行评分,假设:values = [0.1, 0.2, 0.3, 0.2, 0.4, 0.7, 0.5, 0.8] # 每个 token 的 score 而后输出句子中整体评分 0.8
Critic model:只对最后几个 token 的 action 计算 loss,于是:action_values = values[:, -3:] # 即取出最后 3 个生成 token 的 Q 值这些值也就对应了我们模型的生成

3、loss.py

https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py

补充-1:
裁剪使用的是torch.clamp(https://pytorch.org/docs/stable/generated/torch.clamp.html)强制将范围外的数值处理为边界值,范围内数字保持不变

1、PolicyLoss:Policy Loss for PPO

r t = exp ⁡ ( log ⁡ π ( a t ∣ s t ) − log ⁡ π old ( a t ∣ s t ) ) L clip ( t ) = min ⁡ ( r t ⋅ A t , clip ( r t , 1 − ϵ , 1 + ϵ ) ⋅ A t ) L policy = − E t [ L clip ( t ) ] \begin{align*} r_t &= \exp(\log \pi(a_t \mid s_t) - \log \pi_{\text{old}}(a_t \mid s_t)) \\ \mathcal{L}_{\text{clip}}(t) &= \min\left(r_t \cdot A_t,\ \text{clip}(r_t,\ 1 - \epsilon,\ 1 + \epsilon) \cdot A_t\right) \\ \mathcal{L}_{\text{policy}} &= -\mathbb{E}_t \left[ \mathcal{L}_{\text{clip}}(t) \right] \end{align*} rtLclip(t)Lpolicy=exp(logπ(atst)logπold(atst))=min(rtAt, clip(rt, 1ϵ, 1+ϵ)At)=Et[Lclip(t)]

2、ValueLoss: Value Loss for PPO

L value = 1 2 ⋅ E t ∼ mask [ max ⁡ ( ( V clip , t − R t ) 2 , ( V t − R t ) 2 ) ] 其中: V clip = V old + clip ( V − V old , − ϵ , ϵ ) \mathcal{L}_{\text{value}} = \frac{1}{2} \cdot \mathbb{E}_{t \sim \text{mask}} \left[ \max \left( (V_{\text{clip}, t} - R_t)^2, \, (V_t - R_t)^2 \right) \right]\\ \text{其中:}V_{\text{clip}} = V_{\text{old}} + \text{clip}(V - V_{\text{old}}, -\epsilon, \epsilon) Lvalue=21Etmask[max((Vclip,tRt)2,(VtRt)2)]其中:Vclip=Vold+clip(VVold,ϵ,ϵ)

代码测试

修改了代码见链接:https://www.big-yellow-j.top/_jupyter/OpenRLHF_model.py

总结

本文主要介绍了在 OpenRLHF中模型框架设计,主要分为3类模型:1、actor model;2、critic model;3、reward model这三类模型中分别起到作用:1、直接更具prompt输出response;2、输出token的评分(action_values = values[:, -3:]);3、返回整句输出评分(找出最后一个有效 token 的索引,然后从 value 向量中提取该位置的值作为 reward。)

相关文章:

  • IOT项目——DIY Weather Station With ESP32
  • 几种Word转换PDF的常用方法
  • Linux学习笔记2
  • 【前端】【业务逻辑】 数据大屏自适应方案汇总
  • 如何在idea里创建注释模版
  • MIT6.S081 - Lab9 File Systems(文件系统)
  • 【音视频】音频解码实战
  • nodejs使用require导入npm包,开发依赖和生产依赖 ,全局安装
  • 01.浏览器自动化webdriver源码分析之启动函数
  • Uniapp:navigator(页面跳转)
  • qt调用deepseek的API开发(附带源码)
  • Android Studio开发 SharedPreferences 详解
  • 【MATLAB第115期】基于MATLAB的多元时间序列的ARIMAX的预测模型
  • js原型链prototype解释
  • Nature Communications 面向形状可编程磁性软材料的数据驱动设计方法—基于随机设计探索与神经网络的协同优化框架
  • Qt绘制可选择范围的日历
  • 未来教育风向标 | 教育学顶流985高校,华东师范大学《AIGC技术赋能教育数字化转型的机遇与挑战》,13所大学deepseek
  • 深度解析MQTT源码架构与AIGC场景融合实战
  • 三生原理与现有密码学的核心区别?
  • 洗车小程序系统前端uniapp 后台thinkphp
  • 具身智能资本盛宴:3个月37笔融资,北上深争锋BAT下场,人形机器人最火
  • 王毅同英国外交大臣拉米通电话
  • 前瞻2025丨无糖茶,站在转折点?
  • 万斯偕印裔妻子访问印度,4天行程能否推进美印贸易谈判?
  • 哈萨克斯坦一名副市长遭枪击
  • 讲座|在数字化时代,“记住”到底意味着什么