一 宏
1.1 AT_DISPATCH_FLOATING_TYPES
AT_DISPATCH_FLOATING_TYPES 和AT_DISPATCH_ALL_TYPES 的作用-CSDN博客
AT_DISPATCH_FLOATING_TYPES 宏主要用于以下目的:
数据类型调度:根据输入张量的数据类型选择合适的数据类型进行处理。
模板编程:结合 C++ 模板编程,根据不同的数据类型生成不同的代码路径。
代码简化:减少手动写类型检查和类型转换代码的繁琐过程。
使用示例
假设我们有一个简单的 CUDA 内核函数 example_kernel,它对输入张量进行某种操作。我们希望这个内核函数可以处理 float 和 double 类型的数据。以下是如何使用 AT_DISPATCH_FLOATING_TYPES 来实现这个目标。
1. 定义 CUDA 内核
首先,我们定义一个简单的 CUDA 内核函数:
template <typename scalar_t>
__global__ void example_kernel(scalar_t* data, int64_t size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
data[index] *= 2; // 例如,简单地将每个元素乘以2
}
}
2. 定义 C++ 函数并使用 AT_DISPATCH_FLOATING_TYPES
接下来,我们定义一个 C++ 函数,使用 AT_DISPATCH_FLOATING_TYPES 来调度数据类型,并调用相应的 CUDA 内核:
#include <torch/extension.h>
#include <vector>
std::vector<torch::Tensor> example_forward(torch::Tensor input) {
const auto size = input.size(0);
auto output = torch::zeros_like(input);
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(input.type(), "example_forward_cuda", ([&] {
example_kernel<scalar_t><<<blocks, threads>>>(
input.data<scalar_t>(), size);
}));
return {output};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &example_forward, "Example forward");
}
在这段代码中:
AT_DISPATCH_FLOATING_TYPES(input.type(), "example_forward_cuda", ([&] { ... })):
input.type():获取输入张量的数据类型。
"example_forward_cuda":操作的名称,用于错误信息。
[&]:捕获外部变量的 lambda 表达式。
example_kernel<scalar_t><<<blocks, threads>>>(input.data<scalar_t>(), size):根据调度的数据类型调用相应的 CUDA 内核。
3. 在 Python 中调用
最后,我们在 Python 中加载和调用这个扩展:
import torch
from torch.utils.cpp_extension import load
# JIT 编译并加载 C++ 扩展
example_cpp = load(name="example_cpp", sources=["example.cpp"], verbose=True)
# 创建输入张量
input = torch.randn(1024, device='cuda', dtype=torch.float32)
# 调用前向传播函数
output = example_cpp.forward(input)
print(output)