【速写】钩子与计算图
文章目录
- 前向钩子
- 反向钩子的输入
- 反向钩子的输出
前向钩子
下面是一个测试用的计算图的网络,这里因为模型是自定义的缘故,可以直接把前向钩子注册在模型类里面,这样会更加方便一些。其实像以前BERT之类的last_hidden_state
以及pool_output
之类的输出应该也是用钩子钩出来的。
import torch
from torch import nn
from torch.nn import functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.linear_3 = nn.Linear(3, 3)self.linear_4 = nn.Linear(3, 1)self._register_hooks(["linear_3"])def _register_hooks(self, module_names):self.hook_outputs = {}def make_hook(name):def hook(module, input_, output):self.hook_outputs[name]["input"].append(input_)self.hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1 = self.linear_3(y_1_a)print(y_3_1)y_3_2 = self.linear_3(y_2_a)print(y_3_2)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_a
x = torch.FloatTensor([[1,2,3,4]])
net = Net()
y = net(x)
可视化是:
输出结果:
y_1_a: tensor([[0.2428, 0.5258, 0.2866]], grad_fn=<SigmoidBackward0>)
y_2_a: tensor([[0.4860, 0.4801, 0.6515]], grad_fn=<SigmoidBackward0>)
y_3_1: tensor([[ 0.3423, 0.2477, -0.7132]], grad_fn=<AddmmBackward0>)
y_3_2: tensor([[ 0.4148, 0.2024, -0.9481]], grad_fn=<AddmmBackward0>)
而钩子抓到的结果net.hook_outputs
中的内容形如:
{'linear_3': {'input': [[tensor([[0.2428, 0.5258, 0.2866]])],[tensor([[0.4860, 0.4801, 0.6515]])]],'output': [tensor([[ 0.3423, 0.2477, -0.7132]]),tensor([[ 0.4148, 0.2024, -0.9481]])]}}
tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>)
tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>)
tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>)
tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)
{'linear_3': {'input': [(tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>),tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)]}}
是完全对的上的,尽管L3被多次调用,但实际上每次调用都是1个输入1个输出,但是input
钩到的是tuple
,但output
钩到的却是tensor
但是假如我稍作修改,比如把linear_3
的单独化成一个模块self.m = M1()
,它有两个输入,也有两个输出:
class M1(nn.Module):def __init__(self):super(M1, self).__init__()self.linear_3 = nn.Linear(3, 3)def forward(self, y_1_a, y_2_a):y_3_1 = self.linear_3(y_1_a)y_3_2 = self.linear_3(y_2_a)return y_3_1, y_3_2class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.m = M1()self.linear_4 = nn.Linear(3, 1)self._register_hooks(["m"])def _register_hooks(self, module_names):self.hook_outputs = {}def make_hook(name):def hook(module, input_, output):self.hook_outputs[name]["input"].append(input_)self.hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1, y_3_2 = self.m(y_1_a, y_2_a)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_ax = torch.FloatTensor([[1,2,3,4]])
net = Net()
y = net(x)
from pprint import pprint
pprint(net.hook_outputs)
此时输出结果就是:
tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>)
tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>)
tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>)
tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>)
{'m': {'input': [(tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>),tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>))],'output': [(tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>),tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>))]}}
发现两次结果的区别了没有?观察两次钩子的输出:
第一次:
{'linear_3': {'input': [(tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>),tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)]}}
第二次:
{'m': {'input': [(tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>),tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>))],'output': [(tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>),tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>))]}}
是的,input
不管怎么样,总是默认是一个tuple
,哪怕里面只有一个输入张量,但是输出output
第一次其实就是tensor
,第二次则变成了tuple
也就说,如果一个module
有多个输出的时候,依然是会变成tuple
的。
反向钩子的输入
把上面的代码稍作修改,我们添加反向钩子,然后随便写一个损失输出出来进行反向传播,看看情况:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.linear_3 = nn.Linear(3, 3)self.linear_4 = nn.Linear(3, 1)self._register_hooks(["linear_3"])def _register_hooks(self, module_names):self.forward_hook_outputs = {}self.backward_hook_outputs = {}def make_forward_hook(name):def hook(module, input_, output):self.forward_hook_outputs[name]["input"].append(input_)self.forward_hook_outputs[name]["output"].append(output)return hookdef make_backward_hook(name):def hook(module, input_, output):self.backward_hook_outputs[name]["input"].append(input_)self.backward_hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.forward_hook_outputs[module_name] = {"input": [], "output": []}self.backward_hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_forward_hook(module_name))eval(f"self.{module_name}").register_backward_hook(make_backward_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1 = self.linear_3(y_1_a)y_3_2 = self.linear_3(y_2_a)print(y_3_1)print(y_3_2)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_a, (y_4_a - torch.FloatTensor([[1]])) ** 2x = torch.FloatTensor([[1,2,3,4]])
net = Net()
y, loss = net(x)
loss.backward()from pprint import pprintpprint(net.forward_hook_outputs)
pprint(net.backward_hook_outputs)
输出结果:
tensor([[0.7906, 0.2277, 0.3887]], grad_fn=<SigmoidBackward0>)
tensor([[0.5084, 0.4351, 0.3494]], grad_fn=<SigmoidBackward0>)
tensor([[-0.0058, -0.2894, -0.4183]], grad_fn=<AddmmBackward0>)
tensor([[-0.2083, -0.2993, -0.3650]], grad_fn=<AddmmBackward0>)
{'linear_3': {'input': [(tensor([[0.7906, 0.2277, 0.3887]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5084, 0.4351, 0.3494]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[-0.0058, -0.2894, -0.4183]], grad_fn=<AddmmBackward0>),tensor([[-0.2083, -0.2993, -0.3650]], grad_fn=<AddmmBackward0>)]}}
{'linear_3': {'input': [(tensor([ 0.0061, -0.0089, -0.0080]),tensor([[ 0.0085, 0.0027, -0.0065]]),tensor([[ 0.0031, -0.0045, -0.0041],[ 0.0026, -0.0039, -0.0035],[ 0.0021, -0.0031, -0.0028]])),(tensor([ 0.0061, -0.0089, -0.0079]),tensor([[ 0.0085, 0.0027, -0.0064]]),tensor([[ 0.0048, -0.0071, -0.0063],[ 0.0014, -0.0020, -0.0018],[ 0.0024, -0.0035, -0.0031]]))],'output': [(tensor([[ 0.0061, -0.0089, -0.0080]]),),(tensor([[ 0.0061, -0.0089, -0.0079]]),)]}}
发现问题了没有:
反向钩子,哪怕output
只有一个元素,也是返回的是tuple
,而非tensor
以这个例子里linear_3
为例,它有两个输入(y_1_a
和y_2_a
),同时也有两个输出(y_3_1
和y_3_2
),所以它在反向传播的时候,会产生两次输入输出对,体现在上面捕获的时候net.backward_hook_outputs["linear_3"]["input"]
与net.backward_hook_outputs["linear_3"]["input"]
两个列表的长度都是2:
{'linear_3': {'input': [(tensor([ 0.0061, -0.0089, -0.0080]),tensor([[ 0.0085, 0.0027, -0.0065]]),tensor([[ 0.0031, -0.0045, -0.0041],[ 0.0026, -0.0039, -0.0035],[ 0.0021, -0.0031, -0.0028]])),(tensor([ 0.0061, -0.0089, -0.0079]),tensor([[ 0.0085, 0.0027, -0.0064]]),tensor([[ 0.0048, -0.0071, -0.0063],[ 0.0014, -0.0020, -0.0018],[ 0.0024, -0.0035, -0.0031]]))],'output': [(tensor([[ 0.0061, -0.0089, -0.0080]]),),(tensor([[ 0.0061, -0.0089, -0.0079]]),)]}}
反向传播的梯度里,linear_3
这一层的input
每个都有3个张量,形状分别是(3, ), (1, 3), (3, 3)
这个稍许有点费解,我不太能理解为什么会有3个,但是如果要用的话,因为linear_3
的形状是(3, 3)
,所以它的W对应的梯度就是(3, 3)
,另外两个里面还有一个是bias项的梯度(应该就是(3, )
的那个,即第1个),那剩下还有一个会是什么呢?
问了一下deepseek,结论是其实剩下的那个(1, 3)
是输入 x x x的梯度,这倒是也比较合理
对于 Linear
层的反向钩子(register_backward_hook
),其 grad_input
参数的组成确实需要仔细分析。下面我们彻底解析这个问题:
1. Linear 层的梯度构成
对于一个 Linear(in_features, out_features)
层,其前向计算为:
y = x W T + b y = xW^T + b y=xWT+b
反向传播时,PyTorch 会计算以下梯度:
- 输入梯度( ∂ L ∂ x \frac{\partial L}{\partial x} ∂x∂L)
- 权重梯度( ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L)
- 偏置梯度( ∂ L ∂ b \frac{\partial L}{\partial b} ∂b∂L)
在反向钩子的 grad_input
元组中,三个张量的顺序固定为:
- 输入 x x x 的梯度(形状与 x x x 相同)
- 权重 W W W 的梯度(形状与 W W W 相同,即
(out_features, in_features)
) - 偏置 b b b 的梯度(形状与 b b b 相同,即
(out_features,)
)
2. 你的具体案例解析
你观察到 Linear(3, 3)
层的 grad_input
包含三个张量:
(3, )
→ 偏置 b b b 的梯度
(因为bias
的形状是(3,)
)(1, 3)
→ 输入 x x x 的梯度
(原始输入形状为(1, 3)
,所以梯度形状相同)(3, 3)
→ 权重 W W W 的梯度
(权重矩阵的形状是(3, 3)
)
为什么输入梯度是 (1, 3)
?
- 你的前向输入 x x x 的形状可能是
(batch_size=1, in_features=3)
,因此梯度形状与之匹配。
3. 验证实验
通过以下代码可以明确验证梯度顺序:
import torch
import torch.nn as nn# 定义线性层
linear = nn.Linear(3, 3, bias=True)# 注册反向钩子
def backward_hook(module, grad_input, grad_output):print("\nGradients for Linear layer:")print(f"1. Input grad: shape={grad_input[0].shape} (x的梯度)")print(f"2. Weight grad: shape={grad_input[1].shape} (W的梯度)")print(f"3. Bias grad: shape={grad_input[2].shape} (b的梯度)")return None # 不修改梯度handle = linear.register_backward_hook(backward_hook)# 模拟前向和反向
x = torch.randn(1, 3, requires_grad=True) # 输入形状 (1, 3)
y = linear(x)
loss = y.sum()
loss.backward()handle.remove() # 移除钩子
输出结果
# Gradients for Linear layer:
# 1. Input grad: shape=torch.Size([1, 3]) (x的梯度)
# 2. Weight grad: shape=torch.Size([3, 3]) (W的梯度)
# 3. Bias grad: shape=torch.Size([3]) (b的梯度)
4. 特殊情况分析
(1) 当 bias=False
时
如果线性层无偏置(bias=False
),grad_input
将只包含两个张量:
- 输入梯度
- 权重梯度
(2) 批量输入的情况
若输入 x x x 的形状为 (batch_size, 3)
,输入梯度的形状会变为 (batch_size, 3)
:
x = torch.randn(5, 3, requires_grad=True) # 批量大小=5
y = linear(x)
loss = y.sum()
loss.backward()# 输出:
# 1. Input grad: shape=torch.Size([5, 3])
# 2. Weight grad: shape=torch.Size([3, 3])
# 3. Bias grad: shape=torch.Size([3])
5. 梯度计算原理
PyTorch 的 Linear
层反向传播公式:
- 输入梯度:
∂ L ∂ x = ∂ L ∂ y ⋅ W \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot W ∂x∂L=∂y∂L⋅W - 权重梯度:
∂ L ∂ W = ∂ L ∂ y T ⋅ x \frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}^T \cdot x ∂W∂L=∂y∂LT⋅x - 偏置梯度:
∂ L ∂ b = ∑ b a t c h ∂ L ∂ y \frac{\partial L}{\partial b} = \sum_{batch} \frac{\partial L}{\partial y} ∂b∂L=batch∑∂y∂L
6. 总结
grad_input 顺序 | 张量含义 | 形状示例(Linear(3,3) ) |
---|---|---|
第 0 项 | 输入 x x x 的梯度 | (batch_size, 3) |
第 1 项 | 权重 W W W 的梯度 | (3, 3) |
第 2 项 | 偏置 b b b 的梯度 | (3,) |
注:这个存疑,我不知道这个顺序是不是按照它这个顺序来的,我感觉就观察结果而言,应该第0项是bias,第1项是x,第2项是W。持保留意见。
关键结论
- 你观察到的
(3,)
是偏置梯度,(1,3)
是输入梯度,(3,3)
是权重梯度。 - 顺序是固定的,与
forward
的输入/参数顺序无关。 - 如果层无偏置,
grad_input
长度会减 1。
反向钩子的输出
上面已经发现了,反向钩子的输出不管输出结果是一项还是多项,输出总是一个tuple(与前向钩子是不同的),这个输出本身是loss关于 y y y的梯度
y = x W ⊤ + b y=xW^\top+b y=xW⊤+b
在反向传播过程中,反向钩子(register_backward_hook
)捕获的 grad_output
本质上是 损失函数对模块原始输出的梯度,数学上表示为:
grad_output = ∂ L ∂ y \text{grad\_output} = \frac{\partial \mathcal{L}}{\partial y} grad_output=∂y∂L
其中:
- L \mathcal{L} L 是损失函数(标量)
- y y y 是模块的前向输出(可能是张量或元组)
(1) 单输出模块(如 Linear
层)
- 前向计算:
y = x W T + b (假设输入 x ∈ R B × d in , W ∈ R d out × d in , b ∈ R d out ) y = xW^T + b \quad \text{(假设输入 } x \in \mathbb{R}^{B \times d_{\text{in}}}, W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}, b \in \mathbb{R}^{d_{\text{out}}}) y=xWT+b(假设输入 x∈RB×din,W∈Rdout×din,b∈Rdout) - 反向梯度:
grad_output = ( ∂ L ∂ y ) ∈ R B × d out \text{grad\_output} = \left( \frac{\partial \mathcal{L}}{\partial y} \right) \in \mathbb{R}^{B \times d_{\text{out}}} grad_output=(∂y∂L)∈RB×dout
PyTorch 会将其包装为单元素元组:(grad_output,)
(2) 多输出模块(如 LSTM
)
- 前向输出:
y = ( h all , ( h n , c n ) ) (输出序列、最后隐状态和细胞状态) y = (h_{\text{all}}, (h_n, c_n)) \quad \text{(输出序列、最后隐状态和细胞状态)} y=(hall,(hn,cn))(输出序列、最后隐状态和细胞状态) - 反向梯度:
grad_output = ( ∂ L ∂ h all , ∂ L ∂ h n , ∂ L ∂ c n ) \text{grad\_output} = \left( \frac{\partial \mathcal{L}}{\partial h_{\text{all}}}, \frac{\partial \mathcal{L}}{\partial h_n}, \frac{\partial \mathcal{L}}{\partial c_n} \right) grad_output=(∂hall∂L,∂hn∂L,∂cn∂L)
每个分量的形状与前向输出的对应张量形状一致。
PyTorch 的反向钩子(register_backward_hook
)和前向钩子(register_forward_hook
)在输出处理上确实存在这种关键差异。下面我们彻底解析这种设计差异的原因和具体行为:
1. 反向钩子的输出行为
(1) 输出始终是 tuple
无论模块的原始输出是单个张量还是元组,反向钩子的 grad_output
参数 总是以 tuple
形式传递,即使只有一个梯度张量。例如:
def backward_hook(module, grad_input, grad_output):print(type(grad_output)) # 永远是 <class 'tuple'>return None
(2) 结构对应关系
模块输出类型 | 前向钩子的 output | 反向钩子的 grad_output |
---|---|---|
单个张量 | Tensor | (Tensor,) (单元素元组) |
元组/多个输出 | Tuple[Tensor,...] | Tuple[Tensor,...] (同长度) |
2. 设计原因
(1) 一致性处理
PyTorch 选择统一用 tuple
传递反向梯度,是为了:
- 避免条件判断:无论单输出还是多输出,钩子代码无需检查类型。
- 兼容自动微分系统:PyTorch 的 autograd 始终以
tuple
形式传递梯度。
(2) 与前向钩子的对比
- 前向钩子:保留原始输出类型(张量或元组),因为用户可能需要直接使用该值。
- 反向钩子:梯度计算是系统行为,统一格式更安全。
3. 验证实验
(1) 单输出模块(如 Linear
)
import torch
import torch.nn as nnlinear = nn.Linear(3, 3)def hook(module, grad_in, grad_out):print(f"Linear层 grad_out类型: {type(grad_out)}, 长度: {len(grad_out)}")return Nonelinear.register_backward_hook(hook)x = torch.randn(2, 3, requires_grad=True)
y = linear(x) # 单输出
loss = y.sum()
loss.backward()
输出:
Linear层 grad_out类型: <class 'tuple'>, 长度: 1
(2) 多输出模块(如 LSTM
)
lstm = nn.LSTM(3, 3)def hook(module, grad_in, grad_out):print(f"LSTM层 grad_out类型: {type(grad_out)}, 长度: {len(grad_out)}")return Nonelstm.register_backward_hook(hook)x = torch.randn(5, 2, 3) # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x) # 多输出
loss = output.sum() + h_n.sum()
loss.backward()
输出:
LSTM层 grad_out类型: <class 'tuple'>, 长度: 3
4. 实际应用建议
(1) 安全访问梯度
无论模块输出类型如何,始终按元组处理:
def backward_hook(module, grad_in, grad_out):# 安全获取第一个梯度(即使单输出)grad = grad_out[0] if len(grad_out) > 0 else Nonereturn None
(2) 多输出模块的梯度顺序
对于多输出模块(如 LSTM
),grad_out
的顺序与前向输出的顺序一致:
# 前向输出顺序: (output, (h_n, c_n))
# 反向梯度顺序: (grad_output, grad_h_n, grad_c_n)
5. 深入原理
PyTorch 的 grad_output
设计源于其自动微分系统的实现:
- 计算图构建:前向传播时记录输出节点。
- 反向传播:系统统一以
tuple
形式传递梯度,即使只有一个节点。 - 钩子注入:反向钩子接收到的是系统处理后的梯度结构。
总结
特性 | 前向钩子 output | 反向钩子 grad_output |
---|---|---|
类型 | 保持原始类型 | 强制转为 tuple |
单输出处理 | 直接返回 Tensor | 返回 (Tensor,) |
多输出处理 | 返回 Tuple[Tensor,...] | 返回 Tuple[Tensor,...] |
这种设计确保了反向传播梯度处理的统一性,而前向钩子则更注重输出值的原始性。理解这一差异能帮助你更安全地编写调试工具或自定义梯度逻辑。