简单适配torch_npu不支持的ATen算子
- 一、背景说明
-
- 二、实现步骤详解
- 2.1 实现前向、反向传播算子
- 2.2 编译生成动态库
- 2.3 测试验证程序
- 三、关键点解析
-
- 四、验证结果
一、背景说明
1.1 PyTorch扩展机制
PrivateUse1
是PyTorch为第三方设备扩展设计的保留设备类型,允许开发者添加新硬件支持- 当算子在当前设备(如NPU)未实现时,PyTorch会自动回退(fallback)到CPU执行
- 本文以
native_batch_norm
算子为例,演示如何为NPU设备添加自定义实现
1.2 核心概念
- ATen:PyTorch的核心张量运算库,提供超过2000个基础算子
- 内存格式:描述张量在内存中的排布方式,如NCHW(批处理x通道x高度x宽度)
- 自动微分:PyTorch通过记录计算图实现反向传播,需要同时实现前向和反向算子
二、实现步骤详解
2.1 实现前向、反向传播算子
cat > native_batch_norm_npu.cpp <<-'EOF'
#include <torch/library.h>
#include <ATen/EmptyTensor.h>
#include <ATen/Device.h>
#include <ATen/Utils.h>
#include <ATen/native/Resize.h>
#include <c10/core/DeviceType.h> std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_npu(const at::Tensor& input, const c10::optional<at::Tensor>& weight, const c10::optional<at::Tensor>& bias, const c10::optional<at::Tensor>& running_mean, const c10::optional<at::Tensor>& running_var, bool training, double momentum, double eps)
{at::Tensor output = at::empty_like(input);at::Tensor dummy_mean = at::empty