解决qnn htp 后端不支持boolean 数据类型的方法。
一、背景
1.1 问题原因
Qnn 模型在使用fp16的模型转换不支持类型是boolean的cast 算子,因为 htp 后端支持量化数据类型或者fp16,不支持boolean 类型。
${QNN_SDK_ROOT_27}/bin/x86_64-linux-clang/qnn-model-lib-generator -c ./bge_small_fp16.cpp -b ./bge_small_fp16.bin -o output-so-small
也就是图中的算子不支持。
尝试了很多版本,后端,都不支持。没办法只能算子替换了。
1.2 替换算子
初步思路:
Sub↓Cast (to bool)↓Cast (to float32) (另外一个输入,假设是 y)↓ ↓Mul Mul (1 - mask)↓ ↓Add↓Output
-
先做一个 Greater 比较,生成 0/1 tensor
-
再用这个 0/1 tensor 进行
(cond * x) + ((1-cond) * y)
操作, Where(cond, x, y) = cond * x + (1 - cond) * y 可以用Cast
+Mul
+Sub
+Add
基础算子实现。 -
但是生成的还是有boolean 类型数据
不要 Greater
(即不要比较生成bool类型)
不要 BOOL
tensor (因为有些平台对BOOL类型支持不好,比如QNN/DSP/NPU)
直接从 float tensor 生成 0/1 的 float tensor!
改进思路:
可以直接用 Clip + Sign 这种基础算子来实现!
比如:
-
Sign(x)
:-
如果 x > 0,输出 1
-
如果 x == 0,输出 0
-
如果 x < 0,输出 -1
-
-
Clip(Sign(x), 0.0, 1.0)
:-
把负数剪到 0
-
正数(1)保留为 1
-
这样就完美地直接生成了一个 全是 0 或 1 的 FLOAT tensor! ✅ 没有 BOOL 类型,✅ 没有 Greater 节点,✅ 没有 Cast,✅ 全是 float。
real_cond_input ---> Sign ---> Clip(0.0, 1.0) ---> mask (float 0/1 tensor)
二、算子代码实现
1.1 替换算子
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as npdef add_value_info(graph, name, dtype, shape):"""辅助函数:添加中间 tensor 的 shape 和 dtype"""vi = helper.make_tensor_value_info(name, dtype, shape)graph.value_info.append(vi)def add_constant(graph, base_name, value, dtype, shape):const_name = base_name + "_value"const_tensor = helper.make_tensor(name=const_name,data_type=dtype,dims=shape,vals=value)const_node = helper.make_node('Constant',inputs=[],outputs=[const_name],value=const_tensor)graph.node.append(const_node)add_value_info(graph, const_name, dtype, shape)return const_name
def replace_where_and_cast(model_path, output_path):"""替换 onnx 中的 Where 和 Cast 节点,保持功能等效"""# 读取模型model = onnx.load(model_path)nodes = model.graph.nodeprint("old model node number" + str(len(model.graph.node)))new_nodes = []nodes_to_remove = []input_shape = [1,1, 512, 512]for node in model.graph.node:if node.op_type == "Where":# 记录要移除的原始 Wherenodes_to_remove.append(node)# Where输入:[condition, x, y]cond_input = node.input[0]print(cond_input)x_input = node.input[1]print(x_input)y_input = node.input[2]print(y_input)output_name = node.output[0]print(output_name)# 处理可能前面有 Cast 的情况real_cond_input = cond_inputfor sub_node in model.graph.node:if sub_node.output and sub_node.output[0] == cond_input and sub_node.op_type == "Cast":real_cond_input = sub_node.input[0]nodes_to_remove.append(sub_node)break# ========== 关键步骤 ==========# 1. Signsign_output = real_cond_input + "_sign"sign_node = helper.make_node('Sign',inputs=[real_cond_input],outputs=[sign_output],name ="sign_add_my")new_nodes.append(sign_node)add_value_info(model.graph, sign_output, TensorProto.FLOAT, input_shape)# 2. Clip(0,1)clip_output = real_cond_input + "_clip"clip_min_tensor_name = real_cond_input + "_min_value"clip_min_initializer = numpy_helper.from_array(np.zeros(1, dtype=np.float32),name=clip_min_tensor_name)clip_max_tensor_name = real_cond_input + "_max_value"clip_max_initializer = numpy_helper.from_array(np.ones(1, dtype=np.float32),name=clip_max_tensor_name)model.graph.initializer.append(clip_min_initializer)model.graph.initializer.append(clip_max_initializer)# min_val_const_node = add_constant(model.graph, "min_value", 0, TensorProto.FLOAT, input_shape)# max_val_const_node = add_constant(model.graph, "max_value", 1, TensorProto.FLOAT, input_shape)clip_node = helper.make_node('Clip',inputs=[sign_output, clip_min_tensor_name, clip_max_tensor_name],outputs=[clip_output],name="clip_add_my")new_nodes.append(clip_node)add_value_info(model.graph, clip_output, TensorProto.FLOAT, input_shape)# 3. 生成 (1 - mask)one_tensor_name = real_cond_input + "_one"one_initializer = numpy_helper.from_array(np.ones(input_shape, dtype=np.float32),name=one_tensor_name)model.graph.initializer.append(one_initializer)one_minus_mask_output = real_cond_input + "_one_minus_mask"sub_node = helper.make_node('Sub',inputs=[one_tensor_name, clip_output],outputs=[one_minus_mask_output],name="sub_my")new_nodes.append(sub_node)add_value_info(model.graph, one_minus_mask_output, TensorProto.FLOAT, input_shape)# 4. mask * xmask_mul_x_output = real_cond_input + "_mask_mul_x"mul1_node = helper.make_node('Mul',inputs=[clip_output, x_input],outputs=[mask_mul_x_output],name="mul_my")new_nodes.append(mul1_node)add_value_info(model.graph, mask_mul_x_output, TensorProto.FLOAT, input_shape)# 5. (1-mask) * yone_minus_mask_mul_y_output = real_cond_input + "_one_minus_mask_mul_y"mul2_node = helper.make_node('Mul',inputs=[one_minus_mask_output, y_input],outputs=[one_minus_mask_mul_y_output],name="mul_my2")new_nodes.append(mul2_node)add_value_info(model.graph, one_minus_mask_mul_y_output, TensorProto.FLOAT, input_shape)# 6. 加起来得到最终输出add_node = helper.make_node('Add',inputs=[mask_mul_x_output, one_minus_mask_mul_y_output],outputs=[output_name],name="add_my")new_nodes.append(add_node)# output shape 已经有定义,不需要额外addelif node.op_type == 'Cast':# 如果是 Where 的 Cast,不保留if any(wn.input[0] == node.output[0] for wn in nodes if wn.op_type == 'Where'):print(f"Skipping Cast node: {node.name}")continueelse:new_nodes.append(node)else:new_nodes.append(node)# 移除旧节点for node in nodes_to_remove:model.graph.node.remove(node)# 更新新的节点列表model.graph.ClearField('node')model.graph.node.extend(new_nodes)print("new model node number" + str(len(model.graph.node)))# 保存新的模型onnx.save(model, output_path)if __name__ == "__main__":model_path = "./bge_small_model_simple.onnx"output_path = "./bge_replace_cast_where2.onnx"replace_where_and_cast(model_path, output_path)
2.2 运行原始模型和算子替换之后的模型
def run_bge_small_model_onnx():model = AutoModel.from_pretrained("BAAI/bge-small-zh-v1.5")tokenizers = AutoTokenizer.from_pretrained("BAAI/bge-small-zh-v1.5")input_data = "ZhongGuo, nihao, 日本再见, good cat!"device = "cuda" if torch.cuda.is_available() else "cpu"model.to(device)model.eval()input_tensor_data = tokenizers(input_data, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(device)with torch.no_grad():output = model(**input_tensor_data)print("oringal model putput")output_data = output.last_hidden_state.flatten().tolist()[:100]print(len(output.last_hidden_state.flatten().tolist()))print(output_data)print("run modify model")# 步骤 2:加载 ONNX 模型model_path = './bge_replace_cast_where2.onnx' # 替换为你的 ONNX 模型文件路径session = ort.InferenceSession(model_path)# 步骤 3:准备输入数据# 假设模型的输入是一个形状为 (1, 3, 224, 224) 的浮点张量input_name1 = session.get_inputs()[0].nameprint(input_name1)input_data1 = input_tensor_data["input_ids"].numpy()input_name2 = session.get_inputs()[1].nameinput_data2 = input_tensor_data["attention_mask"].numpy()print(input_name2)input_name3 = session.get_inputs()[2].nameinput_data3 = input_tensor_data["token_type_ids"].numpy()print(input_name3)# 步骤 4:运行模型并获取输出replace_model_output = session.run(None, {input_name1: input_data1, input_name2: input_data2, input_name3: input_data3})# 打印输出结果print("replace_model_output shape:", replace_model_output[0].shape)print("replace_model_output data:", replace_model_output[0])replace_model_output_data = replace_model_output[:100]print(len(replace_model_output))print(replace_model_output_data)np.array(replace_model_output).tofile("last_output-onnx_bge_small_replace.raw")
2.3 原始模型和替换算子模型精度对齐
def compare_nchw_data(nchw_file, nchw_file2):data_nchw = read_bin_fp32(nchw_file, shape=[1, 512, 512])print("NCHW 原始数据形状:", data_nchw.shape)print("NCHW 数据统计 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw.min(), data_nchw.max(), data_nchw.mean()))data_nchw2 = read_bin_fp32(nchw_file2, shape=[1, 512, 512])print("NHWC2 原始数据形状:", data_nchw2.shape)print("NHWC2 数据统计 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw2.min(), data_nchw2.max(), data_nchw2.mean()))diff = data_nchw - data_nchw2print("\n==== 差异对比 ====")print("差值 min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(diff.min(), diff.max(), diff.mean()))print(diff)# ==== 打印前100个数据 ====onnx_output_flat = data_nchw.flatten()onnx_output_flat2 = data_nchw2.flatten()print("\n--- 前100个元素 ---")for i in range(100):print(f"[{i}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 打印后100个数据 ====print("\n--- 后100个元素 ---")for i in range(-100, 0):idx = len(onnx_output_flat) + iprint(f"[{idx}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 可选:统计误差 ====max_diff = np.max(onnx_output_flat2 - onnx_output_flat)mean_diff = np.mean(onnx_output_flat2 - onnx_output_flat )min_diff = np.min(onnx_output_flat2 -onnx_output_flat)print(f"\n 总元素数: {onnx_output_flat.size}")print(f" 最大误差: {max_diff}")print(f" 最小误差: {min_diff}")print(f" 平均误差: {mean_diff}")
2.4 对齐结果展示
结果对齐了,表示模型替换成功了。