pybind学习

发布于:2025-07-15 ⋅ 阅读:(13) ⋅ 点赞:(0)

一 宏

1.1 AT_DISPATCH_FLOATING_TYPES        

AT_DISPATCH_FLOATING_TYPES 和AT_DISPATCH_ALL_TYPES 的作用-CSDN博客

 

AT_DISPATCH_FLOATING_TYPES 宏主要用于以下目的:

数据类型调度:根据输入张量的数据类型选择合适的数据类型进行处理。
模板编程:结合 C++ 模板编程,根据不同的数据类型生成不同的代码路径。
代码简化:减少手动写类型检查和类型转换代码的繁琐过程。
使用示例
假设我们有一个简单的 CUDA 内核函数 example_kernel,它对输入张量进行某种操作。我们希望这个内核函数可以处理 float 和 double 类型的数据。以下是如何使用 AT_DISPATCH_FLOATING_TYPES 来实现这个目标。

1. 定义 CUDA 内核
首先,我们定义一个简单的 CUDA 内核函数:

template <typename scalar_t>
__global__ void example_kernel(scalar_t* data, int64_t size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    data[index] *= 2; // 例如,简单地将每个元素乘以2
  }
}

2. 定义 C++ 函数并使用 AT_DISPATCH_FLOATING_TYPES
接下来,我们定义一个 C++ 函数,使用 AT_DISPATCH_FLOATING_TYPES 来调度数据类型,并调用相应的 CUDA 内核:

#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> example_forward(torch::Tensor input) {
  const auto size = input.size(0);
  auto output = torch::zeros_like(input);

  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(input.type(), "example_forward_cuda", ([&] {
    example_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(), size);
  }));

  return {output};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &example_forward, "Example forward");
}


在这段代码中:

AT_DISPATCH_FLOATING_TYPES(input.type(), "example_forward_cuda", ([&] { ... })):
input.type():获取输入张量的数据类型。
"example_forward_cuda":操作的名称,用于错误信息。
[&]:捕获外部变量的 lambda 表达式。
example_kernel<scalar_t><<<blocks, threads>>>(input.data<scalar_t>(), size):根据调度的数据类型调用相应的 CUDA 内核。
3. 在 Python 中调用
最后,我们在 Python 中加载和调用这个扩展:

import torch
from torch.utils.cpp_extension import load

# JIT 编译并加载 C++ 扩展
example_cpp = load(name="example_cpp", sources=["example.cpp"], verbose=True)

# 创建输入张量
input = torch.randn(1024, device='cuda', dtype=torch.float32)

# 调用前向传播函数
output = example_cpp.forward(input)

print(output)


网站公告

今日签到

点亮在社区的每一天
去签到