pytorch支持更多onnx算子

发布于:2024-06-16 ⋅ 阅读:(26) ⋅ 点赞:(0)

pytorch支持更多onnx算子

本文主要参考扩展onnx算子

而要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:

  • 算子在 PyTorch 中有实现
  • 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
  • ONNX 有相应的算子

PyTorch 算子

  • 组合现有算子
  • 添加 TorchScript 算子
  • 添加普通 C++ 拓展算子
    映射方法
  • 为 ATen 算子添加符号函数
  • 为 TorchScript 算子添加符号函数
  • 封装成 torch.autograd.Function 并添加符号函数
    ONNX 算子
  • 使用现有 ONNX 算子
  • 定义新 ONNX 算子

支持ATen算子

ATen 是 PyTorch 内置的 C++ 张量计算库,PyTorch 算子在底层绝大多数计算都是用 ATen 实现的。

针对的问题:ATen有定义,但缺少和ONNX的映射规则。
解决的思路:

  1. 获取Aten算子接口定义。去 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi搜索算子名。如asinh,对应的接口为def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
  2. 添加符号函数
    添加符号函数def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...): ,g有一个op方法,在把 PyTorch 算子转换成 ONNX 算子时,需要在符号函数中调用此方法来为最终的计算图添加一个 ONNX 算子。在最简单的情况下,我们只要把 PyTorch 算子的输入用g.op()一一对应到 ONNX 算子上即可,并把g.op()的返回值作为符号函数的返回值。在情况更复杂时,我们转换一个 PyTorch 算子可能要新建若干个 ONNX 算子。我们先去翻阅一下 ONNX 算子文档,学习一下我们在符号函数里的映射关系 g.op() 里应该怎么写。Asinh 的文档写道:该算子有一个输入 input,一个输出 output,二者的类型都为张量。

代码汇总如下

import torch 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, x): 
        return torch.asinh(x) 
 
from torch.onnx.symbolic_registry import register_op 
 
def asinh_symbolic(g, input, *, out=None): 
    return g.op("Asinh", input) 
 
register_op('asinh', asinh_symbolic, '', 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'asinh.onnx') 

自定义算子

针对的问题:ONNX中没有对应算子的定义,需要自定义ONNX算子,执行转换。

g.op() 是用来定义 ONNX 算子的函数,对于 ONNX 官方定义的算子,g.op() 的第一个参数就是该算子的名称。而对于一个自定义算子,g.op() 的第一个参数是一个带命名空间的算子名。

完整代码

import torch 
import torchvision 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv1 = torch.nn.Conv2d(3, 18, 3) 
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3) 
 
    def forward(self, x): 
        return self.conv2(x, self.conv1(x)) 
 
from torch.onnx import register_custom_op_symbolic 
from torch.onnx.symbolic_helper import parse_args 
 
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none") 
def symbolic(g,  
        input, 
        weight, 
        offset, 
        mask, 
        bias, 
        stride_h, stride_w, 
        pad_h, pad_w, 
        dil_h, dil_w, 
        n_weight_grps, 
        n_offset_grps, 
        use_mask): 
    return g.op("custom::deform_conv2d", input, offset) 
 
register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'dcn.onnx')