当前位置: 首页 > news >正文

【速写】hook与fx

文章目录

  • 问题
  • 方法
    • 方法 1:使用 PyTorch 的 `register_forward_hook`
    • 方法 2:自定义前向传播(修改 `forward` 方法)
    • 方法 3:使用 `output_attentions` 或 `output_hidden_states`
    • 方法 4:使用 `torch.fx` 进行动态追踪
    • 总结
  • 前向钩子的输出到底是什么?
  • 反向钩子应该怎么用?
  • 更好的写法:钩子与装饰器


最近师弟要跑步,目标是减肥(大概跟我差不多高,但比我要重40斤),想着下半年能跑个半马(不过以他目前的水平来说,还是有点差距的,虽然他只是想能在关门前跑完就行了)。

有半个月了,一周跑5天,周日休息,周四下午跟他去打一个小时羽毛球,大概6-7分的配,一天跑2-3km,有点找回快乐跑步的初衷。但实在还是太慢了,不过也无所大谓。

讨论了一些问题,我很惊异于他们做diffuser的,居然融合lora就直接查一下lora权重的绝对值,然后取top-k大的保留,其余的直接dropout掉了

目前diffuser方法有很强的解释性,很多研究都表明了模型的不同层是可以近似对应图层进行解耦的,即比如第1层控制的是layout,第2层控制的是color,第3层控制的是mainbody,因此如果只是去调颜色,就只需要对第2层套lora(我之前以为这个事情目前不太好实现,其实就是target_modules里的名字写得不对,不能按照key或者model.named_modules()来写,要省掉前面几项才行),不过也可以用笨方法冻结:

# 先对所有 q_proj 应用 LoRA
peft_config = LoraConfig(target_modules=["q_proj"])# 然后手动冻结不需要的层
for name, param in model.named_parameters():if "layers.1." in name:  # 例如冻结第1层及之后的参数param.requires_grad = False

问题

这个不是重点,我的观点是,权重的数值分布并不重要( W W W不重要),输出的值分布才是更重要的( W x Wx Wx才是重要的),或者说激活后的输出才是最重要的( L a y e r N o r m ( σ ( W x + b ) ) LayerNorm(\sigma(Wx+b)) LayerNorm(σ(Wx+b))),上个月何凯明发的用tanh替代LayerNorm的那个神文里,说的也是这个道理,其实无所谓到底发生了什么,输出的分布近似tanh,那就别搞什么LayerNorm了。

那么这又是一个老生常谈的问题:

使用transformers调用DeepSeek-R1-Distill-Qwen-32B模型,可以用model.generate方法输出模型的结果,也可以使用model.forward(即model.__call__)调用,现在的问题是,如果想知道模型中每一层(比如说,每一个self_attn的q_proj,k_proj的输出结果),那么我有办法实现吗,在不修改transformers源码的情况下。

以下是几种可行的方法:

  1. 法1:最为常用,就是要学会用钩子,其实这是个很好的习惯,自己写模型的时候也可以提前把钩子写好,以前觉得什么register_bufferregister_hook没有意义,并非如此。
  2. 法2:这个很奇特,不太会写
  3. 法3:传统方案,不管用
  4. 法4:用torch.fx,其实这个本质上就是修改了源码

方法

方法 1:使用 PyTorch 的 register_forward_hook

你可以为模型的某个子模块(如 q_projk_proj)注册一个钩子,在推理时捕获它们的输出。

示例代码

import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储各层输出的字典
outputs = {}def get_layer_output_hook(name):def hook(module, input, output):outputs[name] = outputreturn hook# 注册钩子(以第0层的self_attn的q_proj为例)
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook = q_proj.register_forward_hook(get_layer_output_hook(f"layer_{layer_idx}_q_proj"))# 运行推理
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():model.generate(**inputs, max_new_tokens=5)# 查看输出
print(outputs)  # 包含 layer_0_q_proj 的输出
hook.remove()  # 移除钩子

适用场景

  • 适用于 单次推理 时获取特定层的输出。
  • 可以注册多个钩子,捕获不同层的输出。

方法 2:自定义前向传播(修改 forward 方法)

如果你希望更灵活地控制前向传播过程,可以 临时替换 某些模块的 forward 方法,以记录中间结果。

示例代码

original_forward = None
captured_outputs = []def custom_forward(module, *args, **kwargs):global original_forwardoutput = original_forward(*args, **kwargs)captured_outputs.append(output)return output# 替换第0层 self_attn.q_proj 的 forward 方法
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
original_forward = q_proj.forward
q_proj.forward = lambda *args, **kwargs: custom_forward(q_proj, *args, **kwargs)# 运行推理
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():model.generate(**inputs, max_new_tokens=5)# 恢复原始 forward 方法
q_proj.forward = original_forwardprint(captured_outputs)  # 包含 q_proj 的输出

适用场景

  • 适用于需要 更灵活控制 的情况,比如在训练时动态修改某些层的行为。
  • 比钩子更灵活,但需要手动管理 forward 方法的替换和恢复。

方法 3:使用 output_attentionsoutput_hidden_states

某些 HuggingFace 模型支持直接返回注意力权重或隐藏状态(但可能不适用于所有层或所有模型)。

示例代码

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model(**inputs,output_attentions=True,  # 返回注意力权重output_hidden_states=True,  # 返回所有隐藏状态
)# 获取第0层的隐藏状态
hidden_states = outputs.hidden_states[0]
print(hidden_states.shape)  # [batch_size, seq_len, hidden_dim]# 获取第0层的注意力权重
attentions = outputs.attentions[0]
print(attentions.shape)  # [batch_size, num_heads, seq_len, seq_len]

适用场景

  • 适用于 标准 Transformer 结构,但可能无法获取 q_proj/k_proj 的中间输出。
  • 适用于快速获取 注意力权重隐藏状态

方法 4:使用 torch.fx 进行动态追踪

如果你需要 更复杂的中间结果提取,可以使用 PyTorch 的 torch.fx 动态追踪计算图。

示例代码

from torch.fx import symbolic_trace# 追踪模型的计算图
traced_model = symbolic_trace(model)# 自定义提取中间结果
class ExtractModuleOutputs(torch.nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, *args, **kwargs):# 在这里手动提取中间结果return self.model(*args, **kwargs)# 运行推理
extractor = ExtractModuleOutputs(traced_model)
output = extractor(**inputs)

简单一个例子,比如还是查看第1层的q_proj的输出:

import torch
from torch.fx import symbolic_trace
from transformers import AutoModelForCausalLM, AutoTokenizer# 加载模型和分词器
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 定义一个包装类,用于提取中间输出
class ExtractModuleOutputs(torch.nn.Module):def __init__(self, model, target_module_path):super().__init__()self.model = modelself.target_module_path = target_module_pathself.captured_output = None# 使用 torch.fx 追踪计算图self.traced_model = symbolic_trace(self.model)# 在计算图中插入代码,捕获目标模块的输出self._insert_capture_code()def _insert_capture_code(self):# 找到目标模块(例如 "model.layers.0.self_attn.q_proj")target_module = eval(f"self.model.{self.target_module_path}")# 定义一个钩子函数,用于捕获输出def capture_output(module, input, output):self.captured_output = outputreturn output  # 不影响原始计算# 注册钩子target_module.register_forward_hook(capture_output)def forward(self, *args, **kwargs):# 运行追踪后的模型self.captured_output = None  # 清空之前的捕获output = self.traced_model(*args, **kwargs)return output# 初始化提取器,指定目标模块路径
extractor = ExtractModuleOutputs(model,target_module_path="model.layers[0].self_attn.q_proj"
)# 运行推理并捕获输出
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():logits = extractor(**inputs)  # 正常前向传播q_proj_output = extractor.captured_output  # 提取的 q_proj 输出print("q_proj output shape:", q_proj_output.shape)  # 例如 [batch_size, seq_len, hidden_dim]

用的还是hook来找的,不过你也可以直接把源码复制一份扔到forward里,然后手动logging一下想要的输出,不过这个事情似乎并不是很好实现,因为源码是高度解耦的,不太好直接并起来。

适用场景

  • 适用于 复杂模型,需要 动态修改计算图 的情况。
  • 学习成本较高,但灵活性最强。

总结

方法适用场景灵活性实现难度
钩子 (register_forward_hook)单次推理,获取中间输出中等
自定义 forward需要动态修改前向逻辑
output_attentions/output_hidden_states获取标准隐藏状态/注意力
torch.fx复杂模型,动态修改计算图最高

如果你的目标是 简单获取 q_proj/k_proj 的输出,推荐 方法 1(钩子),因为它不需要修改模型结构,且易于管理。


前向钩子的输出到底是什么?

在方法1中,一般来讲q_proj的钩子所捕获的输出应该是 Q x Qx Qx吧,那加入权重 Q Q Q被调用了两次,比方说,有 Q x 1 Qx_1 Qx1 Q x 2 Qx_2 Qx2,那么outputs里就会有对应的两个输出吗?

deepseek认为我的理解是正确的!在 方法 1(钩子) 中,每次目标模块(如 q_proj)被调用时,钩子都会捕获它的输出。因此,如果 q_proj一次完整的前向传播中被调用了两次(例如计算 (Qx_1) 和 (Qx_2)),那么 outputs 字典中会 按调用顺序存储这两个输出


关键验证实验

我们可以通过一个简单的例子验证这一点。假设 q_proj 在模型中被多次调用(例如,由于多轮注意力计算或不同的输入路径):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储所有调用的输出
outputs = []def hook(module, input, output):outputs.append(output)return output  # 不影响原始计算# 注册钩子到第0层的 q_proj
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook_handle = q_proj.register_forward_hook(hook)# 模拟输入(假设模型内部会多次调用 q_proj)
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)with torch.no_grad():model(**inputs)  # 正常前向传播# 检查 outputs
print(f"Number of q_proj calls: {len(outputs)}")
for i, out in enumerate(outputs):print(f"Call {i+1} output shape: {out.shape}")# 移除钩子
hook_handle.remove()

可能的结果

  1. 如果 q_proj单次前向传播中被调用一次(例如标准的自注意力机制),outputs 会包含 1 个张量(形状如 [batch_size, seq_len, hidden_dim])。
  2. 如果 q_proj 被调用多次(例如模型有 多头注意力分块处理),outputs 会按调用顺序存储所有输出。

为什么会出现多次调用?

  1. 多头注意力(Multi-Head Attention)

    • 某些实现中,q_proj 可能被拆分为多个头(heads),但通常仍是一次性计算 (Qx)。
    • 如果每个头独立调用 q_proj(罕见),则钩子会捕获多次输出。
  2. 分块处理(Chunking)

    • 长序列可能被拆分为块(chunks),导致 q_proj 被多次调用(例如 x1x2 是两块输入)。
  3. 自定义模型逻辑

    • 如果模型在 forward 中显式调用了多次 q_proj(例如某些动态路由机制),钩子会捕获所有调用。

如何区分不同调用的来源?
如果需要明确每次调用对应的输入(例如区分 (Qx_1) 和 (Qx_2)),可以:

  1. 检查输入参数(通过钩子的 input 参数):
    def hook(module, input, output):print(f"Input shape: {input[0].shape}")  # input 是元组,取第一个元素outputs.append(output)
    
  2. 记录调用堆栈(适用于调试):
    import traceback
    def hook(module, input, output):print(f"Call stack:\n{traceback.format_stack()[-3]}")outputs.append(output)
    

总结

  • 钩子会捕获目标模块的所有调用。如果 q_proj 被调用 (N) 次,outputs 会存储 (N) 个输出。
  • 典型情况:在标准 Transformer 中,q_proj 通常 每层每前向传播仅调用一次(计算 (Qx))。
  • 异常情况:如果模型有特殊设计(如分块、动态路由),可能会多次调用,此时需结合输入分析。

如果你的目标是 仅捕获第一次调用,可以在钩子中设置条件逻辑(例如 if len(outputs) == 0)。


反向钩子应该怎么用?

前面的方法针对的是前向钩子,用于捕获推理时的输出值,但是反向钩子同样重要,一般用于捕获梯度。

在 PyTorch 中,register_backward_hook 用于在反向传播(即梯度计算)时捕获模块的 输入梯度、输出梯度权重梯度。它的使用方式与 register_forward_hook 类似,但关注的是反向传播过程。以下是详细说明和示例:


1. register_backward_hook 的基本用法

钩子函数的签名如下:

def backward_hook(module, grad_input, grad_output) -> Tensor or None:# grad_input: 输入的梯度(通常是一个元组,对应 forward 输入的梯度)# grad_output: 输出的梯度(通常是一个元组,对应 forward 输出的梯度)# 可以返回修改后的梯度(可选)return modified_grad_input

关键点

  • grad_input:是 forward 输入的梯度(例如 x 的梯度),格式为 Tuple[Tensor]
  • grad_output:是 forward 输出的梯度(例如 y 的梯度),格式为 Tuple[Tensor]
  • 返回值:可以修改 grad_input 并返回(例如梯度裁剪),如果不需要修改则返回 None

2. 示例:捕获 q_proj 的反向梯度

假设我们想捕获 q_proj 在反向传播时的 输入梯度输出梯度

import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储梯度信息
gradients = {"grad_input": None,"grad_output": None
}def backward_hook(module, grad_input, grad_output):gradients["grad_input"] = grad_input  # 保存输入的梯度gradients["grad_output"] = grad_output  # 保存输出的梯度return None  # 不修改梯度# 注册反向钩子到第0层的 q_proj
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook_handle = q_proj.register_backward_hook(backward_hook)# 准备输入并计算损失
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model(**inputs, labels=inputs["input_ids"])  # 需要 labels 以计算 loss
loss = outputs.loss# 反向传播(触发钩子)
loss.backward()# 检查梯度
print("grad_input (q_proj 输入的梯度):", [g.shape for g in gradients["grad_input"] if g is not None])
print("grad_output (q_proj 输出的梯度):", [g.shape for g in gradients["grad_output"] if g is not None])# 移除钩子
hook_handle.remove()

输出解释

  • grad_input:是 q_proj.forward 的输入(即 x)的梯度,形状通常与 x 相同。
  • grad_output:是 q_proj.forward 的输出(即 Qx)的梯度,形状通常与 Qx 相同。

3. 常见用途

(1) 梯度裁剪(Gradient Clipping)

def gradient_clip_hook(module, grad_input, grad_output):max_norm = 1.0grad_input = tuple(torch.clamp(g, -max_norm, max_norm) for g in grad_input)return grad_input  # 返回修改后的输入梯度

(2) 梯度检查(Debugging)

def debug_hook(module, grad_input, grad_output):print(f"Module: {module.__class__.__name__}")print(f"grad_input: {[g.shape for g in grad_input]}")print(f"grad_output: {[g.shape for g in grad_output]}")return None

(3) 梯度修改(自定义优化策略)

def custom_grad_hook(module, grad_input, grad_output):# 例如:对梯度加权grad_input = tuple(g * 0.5 for g in grad_input)return grad_input

4. 注意事项

  1. 梯度顺序

    • grad_input 的顺序与 forward 的输入顺序一致,但可能包含 None(例如未计算梯度的输入)。
    • 对于 Linear 层(如 q_proj),grad_input 通常是 (x的梯度, weight的梯度, bias的梯度)
  2. 内存占用

    • 反向钩子会保存梯度张量,可能增加内存消耗,调试后需及时移除钩子:
      hook_handle.remove()  # 必须手动移除!
      
  3. 动态图 vs 静态图

    • torch.no_grad()eval() 模式下,反向钩子不会被触发。
    • 如果模型使用 torch.jit.scripttorch.compile,钩子可能失效。

5. 与 register_full_backward_hook 的区别

PyTorch 2.0+ 引入了 register_full_backward_hook,它更强大且兼容自动微分(autograd):

def full_backward_hook(module, grad_input, grad_output) -> Tuple[Tensor]:# 可以同时修改输入和输出梯度return modified_grad_input, modified_grad_outputhook_handle = q_proj.register_full_backward_hook(full_backward_hook)

总结

  • register_backward_hook 用于捕获或修改反向传播的梯度。
  • 典型场景:梯度调试、裁剪、自定义优化。
  • 重要步骤
    1. 定义钩子函数,处理 grad_input/grad_output
    2. 注册到目标模块(如 q_proj)。
    3. 运行 loss.backward() 触发钩子。
    4. 务必移除钩子(避免内存泄漏)。

更好的写法:钩子与装饰器

第二天写到一半发现,hook明显和装饰器更契合,便于代码复用与热插拔。

下面是一个例子:

你的想法非常好!使用 装饰器(Decorator) 来封装前向钩子的注册逻辑,可以大幅简化代码,使其更模块化和可复用。以下是完整的实现示例,包含装饰器的定义和用法:


1. 定义装饰器 register_forward_hook_decorator

这个装饰器的作用是:

  1. 自动注册前向钩子 到指定模块(如 q_proj)。
  2. 捕获模块的输出,并存储到预定义的字典中。
  3. 在函数调用结束后自动移除钩子,避免内存泄漏。
import torch
from functools import wraps
from transformers import AutoModelForCausalLM, AutoTokenizerdef register_forward_hook_decorator(module_path):"""装饰器:注册前向钩子到指定模块,并捕获其输出。参数:module_path (str): 目标模块的路径,例如 "model.layers[0].self_attn.q_proj"。"""def decorator(func):@wraps(func)def wrapper(model, *args, **kwargs):# 获取目标模块target_module = eval(f"model.{module_path}")# 存储输出的字典outputs = {}# 定义钩子函数def hook(module, input, output):outputs["captured_output"] = output# 注册钩子hook_handle = target_module.register_forward_hook(hook)try:# 调用原始函数(例如 model.generate 或 model.forward)result = func(model, *args, **kwargs)# 将捕获的输出附加到结果中(可选)if hasattr(result, "logs"):result.logs = outputselse:result.captured_output = outputs["captured_output"]return resultfinally:# 确保钩子被移除hook_handle.remove()return wrapperreturn decorator

2. 使用装饰器捕获 q_proj 的输出

示例 1:装饰 model.generate

# 加载模型
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 定义装饰函数
@register_forward_hook_decorator("model.layers[0].self_attn.q_proj")
def generate_with_hook(model, inputs, **kwargs):return model.generate(**inputs, **kwargs)# 调用函数(自动捕获输出)
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
output = generate_with_hook(model, inputs, max_new_tokens=5)# 获取捕获的输出
print("Captured q_proj output:", output.captured_output.shape)  # 例如 [batch_size, seq_len, hidden_dim]

示例 2:装饰 model.forward

@register_forward_hook_decorator("model.layers[0].self_attn.k_proj")
def forward_with_hook(model, inputs, **kwargs):return model(**inputs, **kwargs)outputs = forward_with_hook(model, inputs)
print("Captured k_proj output:", outputs.captured_output.shape)

3. 进阶:支持多模块捕获

如果需要同时捕获多个模块的输出,可以扩展装饰器:

def register_multi_hooks_decorator(module_paths):"""装饰器:注册多个前向钩子,捕获多个模块的输出。参数:module_paths (List[str]): 目标模块路径列表,例如 ["model.layers[0].self_attn.q_proj", ...]。"""def decorator(func):@wraps(func)def wrapper(model, *args, **kwargs):outputs = {}handles = []# 为每个模块注册钩子for path in module_paths:module = eval(f"model.{path}")def make_hook(name):def hook(module, input, output):outputs[name] = outputreturn hookhandles.append(module.register_forward_hook(make_hook(path)))try:result = func(model, *args, **kwargs)result.captured_outputs = outputs  # 附加所有输出return resultfinally:for handle in handles:handle.remove()return wrapperreturn decorator# 使用示例
@register_multi_hooks_decorator(["model.layers[0].self_attn.q_proj","model.layers[0].self_attn.k_proj"
])
def forward_with_multi_hooks(model, inputs):return model(**inputs)outputs = forward_with_multi_hooks(model, inputs)
print("q_proj output:", outputs.captured_outputs["model.layers[0].self_attn.q_proj"].shape)
print("k_proj output:", outputs.captured_outputs["model.layers[0].self_attn.k_proj"].shape)

4. 关键点说明

  1. 自动清理钩子
    • 使用 try/finally 确保钩子始终被移除,避免内存泄漏。
  2. 输出附加到结果
    • 将捕获的输出直接附加到模型返回的结果对象上(如 output.captured_output),方便访问。
  3. 灵活性
    • 装饰器可以动态指定任意模块路径(如 "model.layers[1].mlp.dense_h_to_4h")。

总而言之,又是被AI拷打的一天。

相关文章:

  • 国际化不生效
  • 聊聊SpringAI流式输出的底层实现?
  • 安全复健|windows常见取证工具
  • 从零开始搭建Django博客③--前端界面实现
  • 超声三维测试水箱与超声功率计:精准医疗与工业检测的核心技术支撑
  • Java基础 4.23
  • GOC 课程制作
  • YOLO数据处理
  • 树莓派超全系列教程文档--(41)树莓派config.txt旧版内存控制选项
  • system verilog 语句 耗时规则
  • MySQL 锁机制
  • SwiftUI 2.Image介绍和使用
  • leve1.4
  • C# AutoResetEvent 详解
  • HTTP:十一.HTTP认证概述
  • 内存管理(Linux程序设计)
  • 宿主机和容器 ping 不通域名解决方法
  • 51c大模型~合集120
  • 汽车可变转向比系统的全面认识
  • Linux下载与安装
  • 工程院院士应汉杰不再担任苏州大学校长
  • 林毅夫:中美经济确有脱钩风险,但“完全脱钩”可能性不大
  • 话剧《门第》将开启全国巡演:聚焦牺牲、爱与付出
  • AI换脸侵权案入选最高法典型案例:明晰人工智能使用边界
  • 美菲开始举行年度军演,外交部:菲公然站在地区国家的对立面
  • 承认出现误判,以军公布加沙救护车队遭袭事件调查结果