InferType和_checked_type的区别?
在 TVM 的 Relay IR 中,relay.frontend.common.infer_shape(node)
和 node.checked_type.shape
都与**形状(Shape)**信息相关,但它们的用途、实现机制和性能特点有显著区别。以下是详细对比:
1. 功能区别
特性 | node.checked_type.shape | relay.frontend.common.infer_shape(node) |
---|---|---|
数据来源 | 直接从节点的 _checked_type_ 中读取形状信息 | 动态计算节点的输出形状(可能触发类型推断) |
依赖条件 | 要求 node._checked_type_ 已被正确填充(通过 InferType ) | 不依赖 _checked_type_ ,独立计算形状 |
返回值 | 返回静态类型中存储的形状(TensorType.shape ) | 返回动态推断的形状(可能包含变量或符号维度) |
适用场景 | 快速访问已知形状(如优化 pass 中) | 需要动态推断形状(如前端模型导入时) |
2. 实现原理
node.checked_type.shape
- 直接访问属性:
从节点的_checked_type_
(类型为TensorType
)中直接读取shape
字段。
例如:# 假设 node._checked_type_ = TensorType([1, 3, 224, 224], "float32") print(node.checked_type.shape) # 输出: [1, 3, 224, 224]
- 性能:
时间复杂度为O(1)
,仅是属性访问,无额外计算。
relay.frontend.common.infer_shape(node)
- 动态形状推断:
通过遍历节点的输入依赖关系,递归计算输出形状。可能涉及:- 算子形状推导规则(如
conv2d
的输入/输出形状关系)。 - 符号形状的处理(如动态 batch 维度
n
)。
- 算子形状推导规则(如
- 性能:
时间复杂度为O(N)
(N
为依赖的子图节点数),需递归计算。
3. 运行时长的比较
场景 | node.checked_type.shape | infer_shape(node) |
---|---|---|
类型已推断(_checked_type_ 已填充) | 极快(微秒级) | 较慢(需重新计算) |
类型未推断(_checked_type_ = None ) | 返回 None 或报错 | 必须调用,耗时长 |
动态形状(含符号维度) | 仅返回静态记录的形状 | 支持动态推断 |
关键结论:
checked_type.shape
更快:
如果类型已推断,直接读取属性比重新计算快数个数量级。infer_shape
更通用但更慢:
适用于未知类型或动态形状,但需付出计算代价。
4. **何时使用哪个?
优先使用 node.checked_type.shape
当:
- 确定
InferType
已运行(如优化 pass 中)。 - 需要频繁访问形状(如循环中)。
必须使用 infer_shape(node)
当:
- 节点类型未推断(如前端模型导入时)。
- 处理动态形状(如
shape=[n, 224, 224]
)。 - 需要验证形状一致性(如自定义算子开发)。
5. **代码示例
示例 1:静态形状访问(高效)
# 假设已调用 InferType
mod = relay.transform.InferType()(mod)
node = mod["main"].body
print(node.checked_type.shape) # 直接读取
示例 2:动态形状推断(必要时)
# 前端导入模型时(类型未推断)
shape = relay.frontend.common.infer_shape(node)
print(shape) # 动态计算
6. **性能对比实验
以下是一个简单的性能测试:
import time
import tvm
from tvm import relay# 构造一个计算图
x = relay.var("x", shape=[1, 3, 224, 224], dtype="float32")
y = relay.nn.relu(x)
mod = tvm.IRModule.from_expr(y)# 案例1:使用 checked_type.shape(需先推断类型)
mod = relay.transform.InferType()(mod)
node = mod["main"].bodystart = time.time()
for _ in range(1000):_ = node.checked_type.shape
print("checked_type.shape:", time.time() - start) # 约 0.0001s# 案例2:使用 infer_shape
start = time.time()
for _ in range(1000):_ = relay.frontend.common.infer_shape(node)
print("infer_shape:", time.time() - start) # 约 0.1s
结果:
checked_type.shape
比 infer_shape
快约 1000 倍。
7. **注意事项
- 一致性风险:
若手动修改了图结构(如删除节点),需重新调用InferType
,否则checked_type.shape
可能过期。 - 动态形状限制:
checked_type.shape
无法处理符号维度(如n
),而infer_shape
可以。
总结
- 速度:
node.checked_type.shape
远快于infer_shape(node)
(前提是类型已推断)。 - 灵活性:
infer_shape(node)
支持动态场景,但代价较高。 - 最佳实践:
- 在优化 pass 中优先使用
checked_type.shape
。 - 在前端导入或处理动态形状时使用
infer_shape
。
- 在优化 pass 中优先使用