DeepGEMM 论文和实现浅析(一)

发布于:2025-03-05 ⋅ 阅读:(13) ⋅ 点赞:(0)

deepseek 被誉为榨干中低端英伟达显卡的利器, 论文地址

[2304.09049] DeepGEMM: Accelerated Ultra Low-Precision Inference on CPU Architectures using Lookup TablesAbstract page for arXiv paper 2304.09049: DeepGEMM: Accelerated Ultra Low-Precision Inference on CPU Architectures using Lookup Tableshttps://arxiv.org/abs/2304.09049

论文提出了DeepGEMM,这是一种在GPU架构上加速超低精度推理的新方法。关键思想是预先计算所有可能的权重和激活的乘积,在推理时高效地访问它们,以避免昂贵的乘法累加操作。试验证明在x86平台上的性能优于QNNPACK框架中相应的8位整数内核,最高可达1.74倍。本文深入探讨了DeepGEMM算法,包括不同打包方案和矢量化实现的细节。广泛的实验结果表明,与优化的基线和其他超低位技术相比,在x86 GPU上进行推理时可以获得显著的改进。

主要观点

  • 提出了DeepGEMM,这是一种基于查找表的方法,用于在SIMD硬件上执行超低精度的卷积神经网络
  • 预先计算所有可能的权重和激活的乘积,将它们存储在查找表中,并在推理时高效地访问它们,以避免昂贵的乘法累加操作
  • 证明了他们的2位实现在x86平台上的性能比相应的8位整数内核快1.74倍,在QNNPACK框架中
  • 提供了对DeepGEMM算法的深入技术分析,包括不同打包方案和矢量化实现的细节
  • 在x86 GPU上的推理中,与优化的基线和其他超低位技术相比,显示出了显著的改进

github开源地址,目前已有4.6k星

https://github.com/deepseek-ai/DeepGEMMhttps://github.com/deepseek-ai/DeepGEMM

本着学习的目的走查了下库里核心代码,DeepGEMM似乎是少数实现FP8矩阵乘法的用户友好库。

代码结构与主要组件

根据项目文档,DeepGEMM 的代码主要包括以下部分:

组件 作用
内核文件 (fp8_gemm.cuh) 包含核心矩阵乘法函数,使用 FP8 精度执行计算,约 300 行,设计简洁。
JIT 模块 运行时生成内核代码,根据输入参数优化性能,无需提前编译。
测试文件 验证核心功能和 JIT 模块的正确性与效率,包括 test_core.py 和 test_jit.py。

主要函数包括:

  • deep_gemm.gemm_fp8_fp8_bf16_nt:执行标准密集 GEMM,输入为 FP8,输出为 BF16,无转置。
  • m_grouped_gemm_fp8_fp8_bf16_nt_contiguous:处理分组 GEMM,连续布局。
  • m_grouped_gemm_fp8_fp8_bf16_nt_masked:处理分组 GEMM,带掩码布局。

fp8_gemm.cuh文件包含核心函数,执行矩阵乘法,输入输出矩阵及参数如维度和缩放因子。
JIT模块根据输入参数如矩阵维度,在运行时生成内核代码,提升灵活性和优化。环境变量如 DG_CACHE_DIR 和 DG_NVCC_COMPILER 用于控制缓存和编译器设置,优化选项包括持久化线程块特化、Hopper TMA 特性等。

  • DeepGEMM 使用 FP8 格式(8 位浮点数)

相比 FP32 或 FP16 减少内存占用和计算时间。FP8 特别适合 Hopper 张量核心,硬件支持加速矩阵乘法。
内核文件可能包含加载 FP8 输入矩阵到共享内存或寄存器的步骤,然后使用张量核心指令执行乘法。

  • 更细颗粒度的量化

根据 DeepSeek-V3 技术报告,细粒度量化在 1x128 激活瓦片和 128x128 权重块基础上应用缩放,减少量化误差。代码可能包括在线计算最大绝对值的函数,确保缩放因子准确,简化框架。

  • Hopper 张量核心优化

Hopper 架构支持 FP8 操作,代码利用 TMA(张量内存加速器)加载、存储、多播和描述符预取等特性。还可能使用 stmatrix PTX 指令和寄存器计数控制,优化内存访问和计算效率。

  • JIT 编译与灵活性

JIT 模块允许根据矩阵尺寸动态生成代码,适应不同用例,减少通用内核的性能开销。
这在处理非标准矩阵尺寸时特别有用,代码可能包括生成特定线程块和网格配置的逻辑。

大致总结起来就是效率提升

  • 双级累积方法确保低精度计算仍保持精度。
  • 使用CUDA张量核心显著提升矩阵乘法性能。
  • FP8精度减少内存使用,计算可能更快,适合高效计算。
  • 细粒度量化管理精度损失,保持计算准确性。
  • DeepGEMM结合多种优化,使矩阵乘法在目标硬件上尽可能高效。
  • Hopper 张量核心的专用硬件支持加速了矩阵操作。
  • JIT 编译允许针对特定用例动态优化,灵活性更高。

代码细节上重点可以看下这个文件https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/include/deep_gemm/fp8_gemm.cuh

fp8_gemm_kernel 函数
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                uint32_t shape_m,
                const __grid_constant__ CUtensorMap tensor_map_a,
                const __grid_constant__ CUtensorMap tensor_map_b,
                const __grid_constant__ CUtensorMap tensor_map_scales_a,
                const __grid_constant__ CUtensorMap tensor_map_d)

这是一个模板化的 CUDA 内核,核心功能是执行矩阵乘法 D=A×B D = A \times B D=A×B,其中:

  • A 和 B  是 FP8(E4M3 格式)输入矩阵。
  • D 是 BF16 输出矩阵。
  • 支持尺度因子(scales_a 和 scales_b)以管理量化精度。

参数含义:

SHAPE_N, SHAPE_K:全局矩阵维度(输出矩阵列数和公共维度)。
BLOCK_M, BLOCK_N, BLOCK_K:线程块处理的分块维度。
kNumGroups:分组数(支持 MoE 等分组 GEMM)。
kNumStages:流水线阶段数。
kNumTMAThreads, kNumMathThreadsPerGroup:TMA 和计算线程数。
kNumTMAMulticast:TMA 多播集群大小。
kGemmType:GEMM 类型(普通、连续分组、掩码分组)。

代码实现细节

  • 线程分配与共享内存布局
// 线程数计算
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; // TMA + 计算线程
}

// 共享内存声明与对齐
extern __shared__ __align__(1024) uint8_t smem_buffer[]; // 动态共享内存,按 1024 字节对齐
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

// 共享内存分块存储定义
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); // 输出矩阵 D
__nv_fp8_e4m3* smem_a[kNumStages]; // 每阶段的 A 数据
__nv_fp8_e4m3* smem_b[kNumStages]; // 每阶段的 B 数据
float* smem_scales_a[kNumStages];  // A 的尺度因子
float* smem_scales_b;              // B 的尺度因子

// 填充共享内存指针
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
    smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
    smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
    smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
}
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));

// 同步屏障定义
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
    full_barriers[i] = barrier_start_ptr + i;      // 数据加载完成屏障
    empty_barriers[i] = barrier_start_ptr + kNumStages + i; // 数据消费完成屏障
}

线程分为 TMA 线程(负责数据加载)和计算线程(执行矩阵乘法)。
使用动态共享内存(smem_buffer),按 1024 字节对齐,支持 128B 交错访问。
共享内存分块存储 smem_d(输出)、smem_a 和 smem_b(输入)、尺度因子及同步屏障。

get_num_threads_per_sm 计算总线程数,区分 TMA 线程(kNumTMAThreads)和计算线程(kNumMathThreadsPerGroup)。
smem_buffer 使用 __align__(1024) 确保 128B 交错访问效率。
共享内存按偏移量分块分配,确保各部分(smem_d、smem_a、smem_b 等)连续存储。

  • 流水线与多阶段处理

使用 kNumStages 阶段流水线,通过 launch_k_iterations 处理 K 维度的迭代。
支持可整除和不可整除的 K 维度,通过模板元编程优化展开。

// 流水线迭代计算
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);

// 定义可整除和不可整除的 K 维度处理
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
    if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
        for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
            func(k_iter, DivisibleK{}); // 可整除 K
    } else {
        for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
            func(k_iter, DivisibleK{}); // 可整除部分
        func(kNumIterations - 1, NotDivisibleK{}); // 不可整除部分
    }
};

// 使用示例(见 TMA 和 WGMMA 部分)
launch_k_iterations([&](int k_iter, auto type) {
    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
    // 具体流水线逻辑见下文
});

kNumStages 定义流水线阶段数,kFullKOfAllStages 表示所有阶段覆盖的 K 维度。
launch_k_iterations 使用 Lambda 和模板元编程区分 K 是否可被整除,确保循环展开优化。

  • 数据加载(TMA)

TMA 线程利用 Hopper 的张量内存加速器加载 A、B 和尺度因子到共享内存。
支持多播(kNumTMAMulticast),减少全局内存访问。使用屏障同步 TMA 和计算线程。

if (threadIdx.x >= kNumMathThreads) { // TMA 线程
    cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>(); // 释放计算寄存器

    if (threadIdx.x == kNumMathThreads) { // 主 TMA 线程
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            launch_k_iterations([&](int k_iter, auto type) {
                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;

                #pragma unroll
                for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                    empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); // 等待消费完成
                    auto& full_barrier = *full_barriers[s];
                    int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;

                    // TMA 加载 A(支持多播)
                    tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                               smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                    // TMA 加载尺度因子 A
                    tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                               smem_scales_a[s], m_block_idx * BLOCK_M,
                                               scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
                    // TMA 加载 B(无多播)
                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                             smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                    full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); // 通知加载完成
                }
            });
        }
    }
}

TMA 线程通过 tma_copy 使用 TMA 描述符加载数据,支持多播(kNumTMAMulticast)。
full_barriers 和 empty_barriers 确保加载和计算同步。

  • 矩阵乘法(WGMMA)

计算线程使用 Warp Group MMA(WGMMA)指令执行 FP8 矩阵乘法。
累加器使用 FP32,最终结合尺度因子提升到 BF16。

if (threadIdx.x < kNumMathThreads) { // 计算线程
    cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>(); // 分配计算寄存器

    float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; // FP32 累加器和 BF16 最终结果

    while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
        launch_k_iterations([&](int k_iter, auto type) {
            constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;

            #pragma unroll
            for (int s = 0; s < kNumInnerStages; ++ s) {
                float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s); // 读取 B 尺度因子
                full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); // 等待 TMA 完成

                auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); // 读取 A 尺度因子

                warpgroup_arrive(); // 开始 WGMMA
                #pragma unroll
                for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                    auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
                    auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                    WGMMA::wgmma(desc_a, desc_b, accum, k); // 执行矩阵乘法
                }
                warpgroup_commit_batch(); // 提交批处理

                empty_barrier_arrive(s); // 通知消费完成

                // 尺度因子提升
                float scale_0_0 = scale_a_0 * scale_b_0;
                #pragma unroll
                for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                    final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; // FP32 提升到 BF16
                }
            }
        });
    }
}

WGMMA::wgmma 调用张量核心执行 FP8 乘法,结果累加到 accum(FP32)。
尺度因子应用于 final_accum,实现双级累积。

  • 输出存储

计算结果写入共享内存(smem_d),使用优化指令存储。
TMA 线程将结果写回全局内存。

// 写入共享内存
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
    SM90_U32x4_STSM_N<nv_bfloat162>::copy(
        __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
        __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
        smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
    );
}

// TMA 写回全局内存
if (threadIdx.x == 0) {
    cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
    cute::tma_store_arrive();
    cute::tma_store_wait<0>();
}

SM90_U32x4_STSM_N 优化共享内存写入,使用向量存储。
SM90_TMA_STORE_2D 通过 TMA 将结果写回全局内存。

  • 同步机制

使用命名屏障和集群同步协调 TMA 和计算线程。
TMA 多播使用栅栏初始化。

// 初始化屏障
if (threadIdx.x == kNumMathThreads) {
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        full_barriers[i]->init(1); // TMA 加载完成计数
        empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); // 计算完成计数
    }
    cutlass::arch::fence_view_async_shared(); // 异步共享内存栅栏
    (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); // 多播栅栏
}

// 线程同步
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); // 集群或线程块同步

full_barriers 和 empty_barriers 管理数据流动。
cute::cluster_sync 用于多播时的集群同步。

代码片段展示了 fp8_gemm.cuh 如何通过线程分工、共享内存布局、流水线、TMA 加载、WGMMA 计算和同步机制实现高效 FP8 矩阵乘法。这些设计充分利用了 Hopper 架构的硬件特性

Gemm 类 封装内核,提供主机端接口,简化调用。

class Gemm {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

fp8_gemm.cuh 通过以下方式提升矩阵乘法效率:

硬件加速:利用 Hopper 张量核心和 TMA。
低精度优化:FP8 减少内存和计算开销,结合细粒度量化和双级累积保持精度。
流水线设计:多阶段重叠加载与计算。

贴一张官方的测试结果

利用官方给的python api进行测试

# 运行镜像
docker run -it --name conda3-centos7 \
    --gpus all \
    -v /opt/chenrui/DeepGEMM:/mnt \
    -v /opt/chenrui/conda_envs:/usr/local/envs \
    -v /usr/local/cuda-12.4:/usr/local/cuda-12.4 \
    conda/miniconda3-centos7:latest

# 创建env
conda create --name py310 python=3.10

# 安装torch
pip install torch==2.3.1


# 安装 deepgemm
python setup.py develop

可以跑一下测试程序

# 测试 JIT 编译 
python tests/test_jit.py

# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py

需要安装使用

 python setup.py install

库还提供了一些环境变量,可能会很有用:

  • DG_CACHE_DIR:字符串,存储编译内核的缓存目录,默认是 $HOME/.deep_gemm
  • DG_NVCC_COMPILER:字符串,指定 NVCC 编译器的路径;默认在 from torch.utils.cpp_extension.CUDA_HOME 中查找。
  • DG_DISABLE_FFMA_INTERLEAVE:0 或 1,禁用 FFMA 交错优化。
  • DG_PTXAS_VERBOSE:0 或 1,显示详细的 PTXAS 编译器输出。
  • DG_PRINT_REG_REUSE:0 或 1,打印 FFMA 交错的详细信息。
  • DG_JIT_PRINT_NVCC_COMMAND:0 或 1,打印 NVCC 编译命令。
  • DG_JIT_DEBUG:0 或 1,打印更多调试信息。


    网站公告

    今日签到

    点亮在社区的每一天
    去签到