// 调用加法算子
Expr a =Var("a",TensorType({1},DataType::Float(32)));
Expr b =Var("b",TensorType({1},DataType::Float(32)));
Expr add_call =Call(Op::Get("add"),{a, b});
5. Function(函数定义)
作用
封装可复用的计算单元(类似 Lambda 表达式)。
用于表示模型中的子图或复合算子(如 conv2d + relu 融合)。
关键组成
字段
说明
params
输入参数列表(Array<Var>)。
body
函数体的表达式(Expr)。
ret_type
返回值的类型(如 TensorType)。
type_params
泛型类型参数(支持多态,类似 C++ 模板)。
示例
// 定义一个简单的加法函数
Var x("x",TensorType({1},DataType::Float(32)));
Var y("y",TensorType({1},DataType::Float(32)));
Expr body =Call(Op::Get("add"),{x, y});
Function add_func({x, y}, body,TensorType({1},DataType::Float(32)));
6. 节点间的协作关系
计算图示例
z = (x + y) * 2
对应的 Relay IR 结构:
变量:x、y(Var 节点)。
常量:2(Const 节点)。
调用:add(x, y) 和 multiply(add_result, 2)(Call 节点)。
函数:封装整个计算(Function 节点)。
代码实现
Var x("x",TensorType({1},DataType::Float(32)));
Var y("y",TensorType({1},DataType::Float(32)));
Expr add =Call(Op::Get("add"),{x, y});
Expr two =Const(NDArray::FromVector({2.0f}));
Expr mul =Call(Op::Get("multiply"),{add, two});
Function func({x, y}, mul,TensorType({1},DataType::Float(32)));