[PyTorch Extensions] Custom PyTorch Operators in CUDA (Starting from a Simple Demo)

Published: 2024-07-05

PyTorch, as an excellent AI development platform, provides a complete mechanism for defining custom operators. When developing with torch, we often find that the built-in operators fall short and hold our ideas back. So writing custom PyTorch operators in CUDA/C++ is something we simply have to tackle.

Today we will go through a small experiment to lay out everything needed to define a custom PyTorch operator. As an example, let's implement a tensor addition.
vim test_add.py

from add import sum_double_op
import torch
import time

class Timer:
    def __init__(self, op_name):
        self.begin_time = 0
        self.end_time = 0
        self.op_name = op_name

    def __enter__(self):
        torch.cuda.synchronize()
        self.begin_time = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        self.end_time = time.time()
        print(f"Average time cost of {self.op_name} is {(self.end_time - self.begin_time) * 1000:.4f} ms")


if __name__ == '__main__':
    n = 1000000
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tensor1 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    tensor2 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    with Timer("sum_double"):
        ans = sum_double_op(tensor1, tensor2)
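
For reference, you could also time PyTorch's built-in addition with the same Timer and check that the two results agree. The snippet below is only a sketch meant to be appended inside the __main__ block above (the "torch_add" label is just illustrative, and the numbers will vary with GPU and tensor size):

    # optional comparison against the built-in element-wise addition
    with Timer("torch_add"):
        ans_ref = tensor1 + tensor2
    # both implementations should produce the same values
    print(torch.allclose(ans, ans_ref))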

这里的"sum_double_op"就是我们用CUDA写的算子。那这个可以直接调用,并且可以传递梯度的算子,需要怎么做呢?


As we all know, CUDA/C++ are compiled languages, and calling code that has been compiled ahead of time is faster than an interpreted language like Python. So our CUDA code needs a compilation step, which we implement with setuptools (installable via pip) together with PyTorch's cpp_extension utilities.
First, vim setup.py

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='myAdd',
    packages=find_packages(),
    version='0.1.0',
    author='muzhan',
    ext_modules=[
        CUDAExtension(
            'sum_double',
            ['./add/add.cpp',
             './add/add_cuda.cu',]
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
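
As an aside, during development you can skip the install step: torch.utils.cpp_extension also provides a load() helper that JIT-compiles the same sources when the script runs. A minimal sketch, assuming the same add.cpp / add_cuda.cu files (the variable name is arbitrary):

from torch.utils.cpp_extension import load

# JIT-compile the C++/CUDA sources into an importable module (no setup.py needed)
sum_double = load(
    name="sum_double",
    sources=["./add/add.cpp", "./add/add_cuda.cu"],
    verbose=True,
)
# sum_double.forward(...) is then available, just as it is after "python setup.py install"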

Running "python setup.py install" is all it takes to compile and install the CUDA operator. But wait: add.cpp and add_cuda.cu don't exist yet!
vim add_cuda.cu

#include <cstdio>
#define THREADS_PER_BLOCK 256
#define WARP_SIZE 32
#define DIVUP(m, n) (((m) + (n) - 1) / (n))  // ceiling division


__global__ void two_sum_kernel(const float* a, const float* b, float* c, int n){
    // one thread per element: compute the global index and guard against overrun
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n){
        c[idx] = a[idx] + b[idx];
    }
}


void two_sum_launcher(const float* a, const float* b, float* c, int n){
    // launch enough blocks of THREADS_PER_BLOCK threads to cover all n elements
    dim3 blockSize(DIVUP(n, THREADS_PER_BLOCK));
    dim3 threadSize(THREADS_PER_BLOCK);
    two_sum_kernel<<<blockSize, threadSize>>>(a, b, c, n);
}

vim add.cpp

#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)


void two_sum_launcher(const float* a, const float* b, float* c, int n);


void two_sum_gpu(at::Tensor a_tensor, at::Tensor b_tensor, at::Tensor c_tensor){
    CHECK_INPUT(a_tensor);
    CHECK_INPUT(b_tensor);
    CHECK_INPUT(c_tensor);

    const float* a = a_tensor.data_ptr<float>();
    const float* b = b_tensor.data_ptr<float>();
    float* c = c_tensor.data_ptr<float>();
    int n = a_tensor.size(0);
    two_sum_launcher(a, b, c, n);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &two_sum_gpu, "sum two arrays (CUDA)");
}
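
Once the extension has been built and installed, the raw binding can be sanity-checked on its own before writing any wrapper. This is only an illustrative snippet; note that forward writes its result into a pre-allocated output tensor:

import torch
import sum_double  # the module built by setup.py

a = torch.ones(8, dtype=torch.float32, device="cuda")
b = torch.ones(8, dtype=torch.float32, device="cuda")
c = torch.zeros(8, dtype=torch.float32, device="cuda")
sum_double.forward(a, b, c)  # fills c in place
print(c)  # expect a tensor of all 2.0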

Let's take a look at the file structure:

.
├── add
│   ├── add.cpp
│   ├── add_cuda.cu
│   ├── __init__.py
│   └── sum.py
├── README.md
├── setup.py
└── test_add.py

With add.cpp and add_cuda.cu in place, we can run "python setup.py install" to compile and install the extension. Once that is done, we wrap it in a Python class:

vim __init__.py

from .sum import *

vim sum.py

from torch.autograd import Function
import sum_double


class SumDouble(Function):

    @staticmethod
    def forward(ctx, array1, array2):
        """sum_double function forward.
        Args:
            array1 (torch.Tensor): [n,]
            array2 (torch.Tensor): [n,]
        
        Returns:
            ans (torch.Tensor): [n,]
        """
        array1 = array1.float()
        array2 = array2.float()
        ans = array1.new_zeros(array1.shape)
        sum_double.forward(array1.contiguous(), array2.contiguous(), ans)

        # ctx.mark_non_differentiable(ans)  # uncomment if the output does not need backpropagation

        return ans

    @staticmethod
    def backward(ctx, g_out):
        # return None, None  # use this instead if the op does not need backpropagation

        # d(a + b)/da = d(a + b)/db = 1, so the incoming gradient passes straight through to both inputs
        g_in1 = g_out.clone()
        g_in2 = g_out.clone()
        return g_in1, g_in2


sum_double_op = SumDouble.apply

Finally, simply run

python test_add.py
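
To confirm that gradients really flow through the custom op, a small check can be appended to the __main__ block of test_add.py (a sketch; since the op is a plain element-wise sum, each input should receive a gradient of all ones):

    ans.sum().backward()
    print(tensor1.grad)  # expect a tensor of all ones
    print(tensor2.grad)  # expect a tensor of all ones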