【速写】hook与fx
文章目录
- 问题
- 方法
- 方法 1:使用 PyTorch 的 `register_forward_hook`
- 方法 2:自定义前向传播(修改 `forward` 方法)
- 方法 3:使用 `output_attentions` 或 `output_hidden_states`
- 方法 4:使用 `torch.fx` 进行动态追踪
- 总结
- 前向钩子的输出到底是什么?
- 反向钩子应该怎么用?
- 更好的写法:钩子与装饰器
最近师弟要跑步,目标是减肥(大概跟我差不多高,但比我要重40斤),想着下半年能跑个半马(不过以他目前的水平来说,还是有点差距的,虽然他只是想能在关门前跑完就行了)。
有半个月了,一周跑5天,周日休息,周四下午跟他去打一个小时羽毛球,大概6-7分的配,一天跑2-3km,有点找回快乐跑步的初衷。但实在还是太慢了,不过也无所大谓。
讨论了一些问题,我很惊异于他们做diffuser的,居然融合lora就直接查一下lora权重的绝对值,然后取top-k大的保留,其余的直接dropout掉了
目前diffuser方法有很强的解释性,很多研究都表明了模型的不同层是可以近似对应图层进行解耦的,即比如第1层控制的是layout,第2层控制的是color,第3层控制的是mainbody,因此如果只是去调颜色,就只需要对第2层套lora(我之前以为这个事情目前不太好实现,其实就是target_modules里的名字写得不对,不能按照key或者model.named_modules()
来写,要省掉前面几项才行),不过也可以用笨方法冻结:
# 先对所有 q_proj 应用 LoRA
peft_config = LoraConfig(target_modules=["q_proj"])# 然后手动冻结不需要的层
for name, param in model.named_parameters():if "layers.1." in name: # 例如冻结第1层及之后的参数param.requires_grad = False
问题
这个不是重点,我的观点是,权重的数值分布并不重要( W W W不重要),输出的值分布才是更重要的( W x Wx Wx才是重要的),或者说激活后的输出才是最重要的( L a y e r N o r m ( σ ( W x + b ) ) LayerNorm(\sigma(Wx+b)) LayerNorm(σ(Wx+b))),上个月何凯明发的用tanh替代LayerNorm的那个神文里,说的也是这个道理,其实无所谓到底发生了什么,输出的分布近似tanh,那就别搞什么LayerNorm了。
那么这又是一个老生常谈的问题:
使用transformers调用DeepSeek-R1-Distill-Qwen-32B模型,可以用
model.generate
方法输出模型的结果,也可以使用model.forward
(即model.__call__
)调用,现在的问题是,如果想知道模型中每一层(比如说,每一个self_attn的q_proj,k_proj的输出结果),那么我有办法实现吗,在不修改transformers源码的情况下。
以下是几种可行的方法:
- 法1:最为常用,就是要学会用钩子,其实这是个很好的习惯,自己写模型的时候也可以提前把钩子写好,以前觉得什么
register_buffer
、register_hook
没有意义,并非如此。 - 法2:这个很奇特,不太会写
- 法3:传统方案,不管用
- 法4:用
torch.fx
,其实这个本质上就是修改了源码
方法
方法 1:使用 PyTorch 的 register_forward_hook
你可以为模型的某个子模块(如 q_proj
、k_proj
)注册一个钩子,在推理时捕获它们的输出。
示例代码
import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储各层输出的字典
outputs = {}def get_layer_output_hook(name):def hook(module, input, output):outputs[name] = outputreturn hook# 注册钩子(以第0层的self_attn的q_proj为例)
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook = q_proj.register_forward_hook(get_layer_output_hook(f"layer_{layer_idx}_q_proj"))# 运行推理
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():model.generate(**inputs, max_new_tokens=5)# 查看输出
print(outputs) # 包含 layer_0_q_proj 的输出
hook.remove() # 移除钩子
适用场景
- 适用于 单次推理 时获取特定层的输出。
- 可以注册多个钩子,捕获不同层的输出。
方法 2:自定义前向传播(修改 forward
方法)
如果你希望更灵活地控制前向传播过程,可以 临时替换 某些模块的 forward
方法,以记录中间结果。
示例代码
original_forward = None
captured_outputs = []def custom_forward(module, *args, **kwargs):global original_forwardoutput = original_forward(*args, **kwargs)captured_outputs.append(output)return output# 替换第0层 self_attn.q_proj 的 forward 方法
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
original_forward = q_proj.forward
q_proj.forward = lambda *args, **kwargs: custom_forward(q_proj, *args, **kwargs)# 运行推理
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():model.generate(**inputs, max_new_tokens=5)# 恢复原始 forward 方法
q_proj.forward = original_forwardprint(captured_outputs) # 包含 q_proj 的输出
适用场景
- 适用于需要 更灵活控制 的情况,比如在训练时动态修改某些层的行为。
- 比钩子更灵活,但需要手动管理
forward
方法的替换和恢复。
output_attentions
或 output_hidden_states
方法 3:使用 某些 HuggingFace 模型支持直接返回注意力权重或隐藏状态(但可能不适用于所有层或所有模型)。
示例代码
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model(**inputs,output_attentions=True, # 返回注意力权重output_hidden_states=True, # 返回所有隐藏状态
)# 获取第0层的隐藏状态
hidden_states = outputs.hidden_states[0]
print(hidden_states.shape) # [batch_size, seq_len, hidden_dim]# 获取第0层的注意力权重
attentions = outputs.attentions[0]
print(attentions.shape) # [batch_size, num_heads, seq_len, seq_len]
适用场景
- 适用于 标准 Transformer 结构,但可能无法获取
q_proj
/k_proj
的中间输出。 - 适用于快速获取 注意力权重 或 隐藏状态。
方法 4:使用 torch.fx
进行动态追踪
如果你需要 更复杂的中间结果提取,可以使用 PyTorch 的 torch.fx
动态追踪计算图。
示例代码
from torch.fx import symbolic_trace# 追踪模型的计算图
traced_model = symbolic_trace(model)# 自定义提取中间结果
class ExtractModuleOutputs(torch.nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, *args, **kwargs):# 在这里手动提取中间结果return self.model(*args, **kwargs)# 运行推理
extractor = ExtractModuleOutputs(traced_model)
output = extractor(**inputs)
简单一个例子,比如还是查看第1层的q_proj的输出:
import torch
from torch.fx import symbolic_trace
from transformers import AutoModelForCausalLM, AutoTokenizer# 加载模型和分词器
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 定义一个包装类,用于提取中间输出
class ExtractModuleOutputs(torch.nn.Module):def __init__(self, model, target_module_path):super().__init__()self.model = modelself.target_module_path = target_module_pathself.captured_output = None# 使用 torch.fx 追踪计算图self.traced_model = symbolic_trace(self.model)# 在计算图中插入代码,捕获目标模块的输出self._insert_capture_code()def _insert_capture_code(self):# 找到目标模块(例如 "model.layers.0.self_attn.q_proj")target_module = eval(f"self.model.{self.target_module_path}")# 定义一个钩子函数,用于捕获输出def capture_output(module, input, output):self.captured_output = outputreturn output # 不影响原始计算# 注册钩子target_module.register_forward_hook(capture_output)def forward(self, *args, **kwargs):# 运行追踪后的模型self.captured_output = None # 清空之前的捕获output = self.traced_model(*args, **kwargs)return output# 初始化提取器,指定目标模块路径
extractor = ExtractModuleOutputs(model,target_module_path="model.layers[0].self_attn.q_proj"
)# 运行推理并捕获输出
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():logits = extractor(**inputs) # 正常前向传播q_proj_output = extractor.captured_output # 提取的 q_proj 输出print("q_proj output shape:", q_proj_output.shape) # 例如 [batch_size, seq_len, hidden_dim]
用的还是hook来找的,不过你也可以直接把源码复制一份扔到forward
里,然后手动logging一下想要的输出,不过这个事情似乎并不是很好实现,因为源码是高度解耦的,不太好直接并起来。
适用场景
- 适用于 复杂模型,需要 动态修改计算图 的情况。
- 学习成本较高,但灵活性最强。
总结
方法 | 适用场景 | 灵活性 | 实现难度 |
---|---|---|---|
钩子 (register_forward_hook ) | 单次推理,获取中间输出 | 中等 | 低 |
自定义 forward | 需要动态修改前向逻辑 | 高 | 中 |
output_attentions /output_hidden_states | 获取标准隐藏状态/注意力 | 低 | 低 |
torch.fx | 复杂模型,动态修改计算图 | 最高 | 高 |
如果你的目标是 简单获取 q_proj
/k_proj
的输出,推荐 方法 1(钩子),因为它不需要修改模型结构,且易于管理。
前向钩子的输出到底是什么?
在方法1中,一般来讲q_proj
的钩子所捕获的输出应该是 Q x Qx Qx吧,那加入权重 Q Q Q被调用了两次,比方说,有 Q x 1 Qx_1 Qx1和 Q x 2 Qx_2 Qx2,那么outputs
里就会有对应的两个输出吗?
deepseek认为我的理解是正确的!在 方法 1(钩子) 中,每次目标模块(如 q_proj
)被调用时,钩子都会捕获它的输出。因此,如果 q_proj
在 一次完整的前向传播中被调用了两次(例如计算 (Qx_1) 和 (Qx_2)),那么 outputs
字典中会 按调用顺序存储这两个输出。
关键验证实验
我们可以通过一个简单的例子验证这一点。假设 q_proj
在模型中被多次调用(例如,由于多轮注意力计算或不同的输入路径):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储所有调用的输出
outputs = []def hook(module, input, output):outputs.append(output)return output # 不影响原始计算# 注册钩子到第0层的 q_proj
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook_handle = q_proj.register_forward_hook(hook)# 模拟输入(假设模型内部会多次调用 q_proj)
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)with torch.no_grad():model(**inputs) # 正常前向传播# 检查 outputs
print(f"Number of q_proj calls: {len(outputs)}")
for i, out in enumerate(outputs):print(f"Call {i+1} output shape: {out.shape}")# 移除钩子
hook_handle.remove()
可能的结果
- 如果
q_proj
在 单次前向传播中被调用一次(例如标准的自注意力机制),outputs
会包含 1 个张量(形状如[batch_size, seq_len, hidden_dim]
)。 - 如果
q_proj
被调用多次(例如模型有 多头注意力 或 分块处理),outputs
会按调用顺序存储所有输出。
为什么会出现多次调用?
-
多头注意力(Multi-Head Attention)
- 某些实现中,
q_proj
可能被拆分为多个头(heads),但通常仍是一次性计算 (Qx)。 - 如果每个头独立调用
q_proj
(罕见),则钩子会捕获多次输出。
- 某些实现中,
-
分块处理(Chunking)
- 长序列可能被拆分为块(chunks),导致
q_proj
被多次调用(例如x1
和x2
是两块输入)。
- 长序列可能被拆分为块(chunks),导致
-
自定义模型逻辑
- 如果模型在
forward
中显式调用了多次q_proj
(例如某些动态路由机制),钩子会捕获所有调用。
- 如果模型在
如何区分不同调用的来源?
如果需要明确每次调用对应的输入(例如区分 (Qx_1) 和 (Qx_2)),可以:
- 检查输入参数(通过钩子的
input
参数):def hook(module, input, output):print(f"Input shape: {input[0].shape}") # input 是元组,取第一个元素outputs.append(output)
- 记录调用堆栈(适用于调试):
import traceback def hook(module, input, output):print(f"Call stack:\n{traceback.format_stack()[-3]}")outputs.append(output)
总结
- 钩子会捕获目标模块的所有调用。如果
q_proj
被调用 (N) 次,outputs
会存储 (N) 个输出。 - 典型情况:在标准 Transformer 中,
q_proj
通常 每层每前向传播仅调用一次(计算 (Qx))。 - 异常情况:如果模型有特殊设计(如分块、动态路由),可能会多次调用,此时需结合输入分析。
如果你的目标是 仅捕获第一次调用,可以在钩子中设置条件逻辑(例如 if len(outputs) == 0
)。
反向钩子应该怎么用?
前面的方法针对的是前向钩子,用于捕获推理时的输出值,但是反向钩子同样重要,一般用于捕获梯度。
在 PyTorch 中,register_backward_hook
用于在反向传播(即梯度计算)时捕获模块的 输入梯度、输出梯度 或 权重梯度。它的使用方式与 register_forward_hook
类似,但关注的是反向传播过程。以下是详细说明和示例:
1. register_backward_hook
的基本用法
钩子函数的签名如下:
def backward_hook(module, grad_input, grad_output) -> Tensor or None:# grad_input: 输入的梯度(通常是一个元组,对应 forward 输入的梯度)# grad_output: 输出的梯度(通常是一个元组,对应 forward 输出的梯度)# 可以返回修改后的梯度(可选)return modified_grad_input
关键点
grad_input
:是forward
输入的梯度(例如x
的梯度),格式为Tuple[Tensor]
。grad_output
:是forward
输出的梯度(例如y
的梯度),格式为Tuple[Tensor]
。- 返回值:可以修改
grad_input
并返回(例如梯度裁剪),如果不需要修改则返回None
。
2. 示例:捕获 q_proj
的反向梯度
假设我们想捕获 q_proj
在反向传播时的 输入梯度 和 输出梯度:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 存储梯度信息
gradients = {"grad_input": None,"grad_output": None
}def backward_hook(module, grad_input, grad_output):gradients["grad_input"] = grad_input # 保存输入的梯度gradients["grad_output"] = grad_output # 保存输出的梯度return None # 不修改梯度# 注册反向钩子到第0层的 q_proj
layer_idx = 0
q_proj = model.model.layers[layer_idx].self_attn.q_proj
hook_handle = q_proj.register_backward_hook(backward_hook)# 准备输入并计算损失
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model(**inputs, labels=inputs["input_ids"]) # 需要 labels 以计算 loss
loss = outputs.loss# 反向传播(触发钩子)
loss.backward()# 检查梯度
print("grad_input (q_proj 输入的梯度):", [g.shape for g in gradients["grad_input"] if g is not None])
print("grad_output (q_proj 输出的梯度):", [g.shape for g in gradients["grad_output"] if g is not None])# 移除钩子
hook_handle.remove()
输出解释
grad_input
:是q_proj.forward
的输入(即x
)的梯度,形状通常与x
相同。grad_output
:是q_proj.forward
的输出(即Qx
)的梯度,形状通常与Qx
相同。
3. 常见用途
(1) 梯度裁剪(Gradient Clipping)
def gradient_clip_hook(module, grad_input, grad_output):max_norm = 1.0grad_input = tuple(torch.clamp(g, -max_norm, max_norm) for g in grad_input)return grad_input # 返回修改后的输入梯度
(2) 梯度检查(Debugging)
def debug_hook(module, grad_input, grad_output):print(f"Module: {module.__class__.__name__}")print(f"grad_input: {[g.shape for g in grad_input]}")print(f"grad_output: {[g.shape for g in grad_output]}")return None
(3) 梯度修改(自定义优化策略)
def custom_grad_hook(module, grad_input, grad_output):# 例如:对梯度加权grad_input = tuple(g * 0.5 for g in grad_input)return grad_input
4. 注意事项
-
梯度顺序
grad_input
的顺序与forward
的输入顺序一致,但可能包含None
(例如未计算梯度的输入)。- 对于
Linear
层(如q_proj
),grad_input
通常是(x的梯度, weight的梯度, bias的梯度)
。
-
内存占用
- 反向钩子会保存梯度张量,可能增加内存消耗,调试后需及时移除钩子:
hook_handle.remove() # 必须手动移除!
- 反向钩子会保存梯度张量,可能增加内存消耗,调试后需及时移除钩子:
-
动态图 vs 静态图
- 在
torch.no_grad()
或eval()
模式下,反向钩子不会被触发。 - 如果模型使用
torch.jit.script
或torch.compile
,钩子可能失效。
- 在
5. 与 register_full_backward_hook
的区别
PyTorch 2.0+ 引入了 register_full_backward_hook
,它更强大且兼容自动微分(autograd):
def full_backward_hook(module, grad_input, grad_output) -> Tuple[Tensor]:# 可以同时修改输入和输出梯度return modified_grad_input, modified_grad_outputhook_handle = q_proj.register_full_backward_hook(full_backward_hook)
总结
register_backward_hook
用于捕获或修改反向传播的梯度。- 典型场景:梯度调试、裁剪、自定义优化。
- 重要步骤:
- 定义钩子函数,处理
grad_input
/grad_output
。 - 注册到目标模块(如
q_proj
)。 - 运行
loss.backward()
触发钩子。 - 务必移除钩子(避免内存泄漏)。
- 定义钩子函数,处理
更好的写法:钩子与装饰器
第二天写到一半发现,hook明显和装饰器更契合,便于代码复用与热插拔。
下面是一个例子:
你的想法非常好!使用 装饰器(Decorator) 来封装前向钩子的注册逻辑,可以大幅简化代码,使其更模块化和可复用。以下是完整的实现示例,包含装饰器的定义和用法:
1. 定义装饰器 register_forward_hook_decorator
这个装饰器的作用是:
- 自动注册前向钩子 到指定模块(如
q_proj
)。 - 捕获模块的输出,并存储到预定义的字典中。
- 在函数调用结束后自动移除钩子,避免内存泄漏。
import torch
from functools import wraps
from transformers import AutoModelForCausalLM, AutoTokenizerdef register_forward_hook_decorator(module_path):"""装饰器:注册前向钩子到指定模块,并捕获其输出。参数:module_path (str): 目标模块的路径,例如 "model.layers[0].self_attn.q_proj"。"""def decorator(func):@wraps(func)def wrapper(model, *args, **kwargs):# 获取目标模块target_module = eval(f"model.{module_path}")# 存储输出的字典outputs = {}# 定义钩子函数def hook(module, input, output):outputs["captured_output"] = output# 注册钩子hook_handle = target_module.register_forward_hook(hook)try:# 调用原始函数(例如 model.generate 或 model.forward)result = func(model, *args, **kwargs)# 将捕获的输出附加到结果中(可选)if hasattr(result, "logs"):result.logs = outputselse:result.captured_output = outputs["captured_output"]return resultfinally:# 确保钩子被移除hook_handle.remove()return wrapperreturn decorator
2. 使用装饰器捕获 q_proj
的输出
示例 1:装饰 model.generate
# 加载模型
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")# 定义装饰函数
@register_forward_hook_decorator("model.layers[0].self_attn.q_proj")
def generate_with_hook(model, inputs, **kwargs):return model.generate(**inputs, **kwargs)# 调用函数(自动捕获输出)
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
output = generate_with_hook(model, inputs, max_new_tokens=5)# 获取捕获的输出
print("Captured q_proj output:", output.captured_output.shape) # 例如 [batch_size, seq_len, hidden_dim]
示例 2:装饰 model.forward
@register_forward_hook_decorator("model.layers[0].self_attn.k_proj")
def forward_with_hook(model, inputs, **kwargs):return model(**inputs, **kwargs)outputs = forward_with_hook(model, inputs)
print("Captured k_proj output:", outputs.captured_output.shape)
3. 进阶:支持多模块捕获
如果需要同时捕获多个模块的输出,可以扩展装饰器:
def register_multi_hooks_decorator(module_paths):"""装饰器:注册多个前向钩子,捕获多个模块的输出。参数:module_paths (List[str]): 目标模块路径列表,例如 ["model.layers[0].self_attn.q_proj", ...]。"""def decorator(func):@wraps(func)def wrapper(model, *args, **kwargs):outputs = {}handles = []# 为每个模块注册钩子for path in module_paths:module = eval(f"model.{path}")def make_hook(name):def hook(module, input, output):outputs[name] = outputreturn hookhandles.append(module.register_forward_hook(make_hook(path)))try:result = func(model, *args, **kwargs)result.captured_outputs = outputs # 附加所有输出return resultfinally:for handle in handles:handle.remove()return wrapperreturn decorator# 使用示例
@register_multi_hooks_decorator(["model.layers[0].self_attn.q_proj","model.layers[0].self_attn.k_proj"
])
def forward_with_multi_hooks(model, inputs):return model(**inputs)outputs = forward_with_multi_hooks(model, inputs)
print("q_proj output:", outputs.captured_outputs["model.layers[0].self_attn.q_proj"].shape)
print("k_proj output:", outputs.captured_outputs["model.layers[0].self_attn.k_proj"].shape)
4. 关键点说明
- 自动清理钩子
- 使用
try/finally
确保钩子始终被移除,避免内存泄漏。
- 使用
- 输出附加到结果
- 将捕获的输出直接附加到模型返回的结果对象上(如
output.captured_output
),方便访问。
- 将捕获的输出直接附加到模型返回的结果对象上(如
- 灵活性
- 装饰器可以动态指定任意模块路径(如
"model.layers[1].mlp.dense_h_to_4h"
)。
- 装饰器可以动态指定任意模块路径(如
总而言之,又是被AI拷打的一天。