deepseek 被誉为榨干中低端英伟达显卡的利器, 论文地址
论文提出了DeepGEMM,这是一种在GPU架构上加速超低精度推理的新方法。关键思想是预先计算所有可能的权重和激活的乘积,在推理时高效地访问它们,以避免昂贵的乘法累加操作。试验证明在x86平台上的性能优于QNNPACK框架中相应的8位整数内核,最高可达1.74倍。本文深入探讨了DeepGEMM算法,包括不同打包方案和矢量化实现的细节。广泛的实验结果表明,与优化的基线和其他超低位技术相比,在x86 GPU上进行推理时可以获得显著的改进。
主要观点
- 提出了DeepGEMM,这是一种基于查找表的方法,用于在SIMD硬件上执行超低精度的卷积神经网络
- 预先计算所有可能的权重和激活的乘积,将它们存储在查找表中,并在推理时高效地访问它们,以避免昂贵的乘法累加操作
- 证明了他们的2位实现在x86平台上的性能比相应的8位整数内核快1.74倍,在QNNPACK框架中
- 提供了对DeepGEMM算法的深入技术分析,包括不同打包方案和矢量化实现的细节
- 在x86 GPU上的推理中,与优化的基线和其他超低位技术相比,显示出了显著的改进
github开源地址,目前已有4.6k星
https://github.com/deepseek-ai/DeepGEMMhttps://github.com/deepseek-ai/DeepGEMM
本着学习的目的走查了下库里核心代码,DeepGEMM似乎是少数实现FP8矩阵乘法的用户友好库。
代码结构与主要组件
根据项目文档,DeepGEMM 的代码主要包括以下部分:
组件 | 作用 |
---|---|
内核文件 (fp8_gemm.cuh) | 包含核心矩阵乘法函数,使用 FP8 精度执行计算,约 300 行,设计简洁。 |
JIT 模块 | 运行时生成内核代码,根据输入参数优化性能,无需提前编译。 |
测试文件 | 验证核心功能和 JIT 模块的正确性与效率,包括 test_core.py 和 test_jit.py。 |
主要函数包括:
- deep_gemm.gemm_fp8_fp8_bf16_nt:执行标准密集 GEMM,输入为 FP8,输出为 BF16,无转置。
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous:处理分组 GEMM,连续布局。
- m_grouped_gemm_fp8_fp8_bf16_nt_masked:处理分组 GEMM,带掩码布局。
fp8_gemm.cuh文件包含核心函数,执行矩阵乘法,输入输出矩阵及参数如维度和缩放因子。
JIT模块根据输入参数如矩阵维度,在运行时生成内核代码,提升灵活性和优化。环境变量如 DG_CACHE_DIR 和 DG_NVCC_COMPILER 用于控制缓存和编译器设置,优化选项包括持久化线程块特化、Hopper TMA 特性等。
- DeepGEMM 使用 FP8 格式(8 位浮点数)
相比 FP32 或 FP16 减少内存占用和计算时间。FP8 特别适合 Hopper 张量核心,硬件支持加速矩阵乘法。
内核文件可能包含加载 FP8 输入矩阵到共享内存或寄存器的步骤,然后使用张量核心指令执行乘法。
- 更细颗粒度的量化
根据 DeepSeek-V3 技术报告,细粒度量化在 1x128 激活瓦片和 128x128 权重块基础上应用缩放,减少量化误差。代码可能包括在线计算最大绝对值的函数,确保缩放因子准确,简化框架。
Hopper 张量核心优化
Hopper 架构支持 FP8 操作,代码利用 TMA(张量内存加速器)加载、存储、多播和描述符预取等特性。还可能使用 stmatrix PTX 指令和寄存器计数控制,优化内存访问和计算效率。
- JIT 编译与灵活性:
JIT 模块允许根据矩阵尺寸动态生成代码,适应不同用例,减少通用内核的性能开销。
这在处理非标准矩阵尺寸时特别有用,代码可能包括生成特定线程块和网格配置的逻辑。
大致总结起来就是效率提升
- 双级累积方法确保低精度计算仍保持精度。
- 使用CUDA张量核心显著提升矩阵乘法性能。
- FP8精度减少内存使用,计算可能更快,适合高效计算。
- 细粒度量化管理精度损失,保持计算准确性。
- DeepGEMM结合多种优化,使矩阵乘法在目标硬件上尽可能高效。
- Hopper 张量核心的专用硬件支持加速了矩阵操作。
JIT 编译允许针对特定用例动态优化,灵活性更高。
代码细节上重点可以看下这个文件https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/include/deep_gemm/fp8_gemm.cuh
fp8_gemm_kernel 函数
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d)
这是一个模板化的 CUDA 内核,核心功能是执行矩阵乘法 D=A×B D = A \times B D=A×B,其中:
- A 和 B 是 FP8(E4M3 格式)输入矩阵。
- D 是 BF16 输出矩阵。
- 支持尺度因子(scales_a 和 scales_b)以管理量化精度。
参数含义:
SHAPE_N, SHAPE_K:全局矩阵维度(输出矩阵列数和公共维度)。
BLOCK_M, BLOCK_N, BLOCK_K:线程块处理的分块维度。
kNumGroups:分组数(支持 MoE 等分组 GEMM)。
kNumStages:流水线阶段数。
kNumTMAThreads, kNumMathThreadsPerGroup:TMA 和计算线程数。
kNumTMAMulticast:TMA 多播集群大小。
kGemmType:GEMM 类型(普通、连续分组、掩码分组)。
代码实现细节
线程分配与共享内存布局
// 线程数计算
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; // TMA + 计算线程
}
// 共享内存声明与对齐
extern __shared__ __align__(1024) uint8_t smem_buffer[]; // 动态共享内存,按 1024 字节对齐
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// 共享内存分块存储定义
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); // 输出矩阵 D
__nv_fp8_e4m3* smem_a[kNumStages]; // 每阶段的 A 数据
__nv_fp8_e4m3* smem_b[kNumStages]; // 每阶段的 B 数据
float* smem_scales_a[kNumStages]; // A 的尺度因子
float* smem_scales_b; // B 的尺度因子
// 填充共享内存指针
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
}
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
// 同步屏障定义
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i; // 数据加载完成屏障
empty_barriers[i] = barrier_start_ptr + kNumStages + i; // 数据消费完成屏障
}
线程分为 TMA 线程(负责数据加载)和计算线程(执行矩阵乘法)。
使用动态共享内存(smem_buffer),按 1024 字节对齐,支持 128B 交错访问。
共享内存分块存储 smem_d(输出)、smem_a 和 smem_b(输入)、尺度因子及同步屏障。
get_num_threads_per_sm 计算总线程数,区分 TMA 线程(kNumTMAThreads)和计算线程(kNumMathThreadsPerGroup)。
smem_buffer 使用 __align__(1024) 确保 128B 交错访问效率。
共享内存按偏移量分块分配,确保各部分(smem_d、smem_a、smem_b 等)连续存储。
流水线与多阶段处理
使用 kNumStages 阶段流水线,通过 launch_k_iterations 处理 K 维度的迭代。
支持可整除和不可整除的 K 维度,通过模板元编程优化展开。
// 流水线迭代计算
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
// 定义可整除和不可整除的 K 维度处理
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
func(k_iter, DivisibleK{}); // 可整除 K
} else {
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}); // 可整除部分
func(kNumIterations - 1, NotDivisibleK{}); // 不可整除部分
}
};
// 使用示例(见 TMA 和 WGMMA 部分)
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
// 具体流水线逻辑见下文
});
kNumStages 定义流水线阶段数,kFullKOfAllStages 表示所有阶段覆盖的 K 维度。
launch_k_iterations 使用 Lambda 和模板元编程区分 K 是否可被整除,确保循环展开优化。
数据加载(TMA)
TMA 线程利用 Hopper 的张量内存加速器加载 A、B 和尺度因子到共享内存。
支持多播(kNumTMAMulticast),减少全局内存访问。使用屏障同步 TMA 和计算线程。
if (threadIdx.x >= kNumMathThreads) { // TMA 线程
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>(); // 释放计算寄存器
if (threadIdx.x == kNumMathThreads) { // 主 TMA 线程
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); // 等待消费完成
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
// TMA 加载 A(支持多播)
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
// TMA 加载尺度因子 A
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// TMA 加载 B(无多播)
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); // 通知加载完成
}
});
}
}
}
TMA 线程通过 tma_copy 使用 TMA 描述符加载数据,支持多播(kNumTMAMulticast)。
full_barriers 和 empty_barriers 确保加载和计算同步。
矩阵乘法(WGMMA)
计算线程使用 Warp Group MMA(WGMMA)指令执行 FP8 矩阵乘法。
累加器使用 FP32,最终结合尺度因子提升到 BF16。
if (threadIdx.x < kNumMathThreads) { // 计算线程
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>(); // 分配计算寄存器
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; // FP32 累加器和 BF16 最终结果
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s); // 读取 B 尺度因子
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); // 等待 TMA 完成
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); // 读取 A 尺度因子
warpgroup_arrive(); // 开始 WGMMA
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k); // 执行矩阵乘法
}
warpgroup_commit_batch(); // 提交批处理
empty_barrier_arrive(s); // 通知消费完成
// 尺度因子提升
float scale_0_0 = scale_a_0 * scale_b_0;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; // FP32 提升到 BF16
}
}
});
}
}
WGMMA::wgmma 调用张量核心执行 FP8 乘法,结果累加到 accum(FP32)。
尺度因子应用于 final_accum,实现双级累积。
输出存储
计算结果写入共享内存(smem_d),使用优化指令存储。
TMA 线程将结果写回全局内存。
// 写入共享内存
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
);
}
// TMA 写回全局内存
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
SM90_U32x4_STSM_N 优化共享内存写入,使用向量存储。
SM90_TMA_STORE_2D 通过 TMA 将结果写回全局内存。
同步机制
使用命名屏障和集群同步协调 TMA 和计算线程。
TMA 多播使用栅栏初始化。
// 初始化屏障
if (threadIdx.x == kNumMathThreads) {
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1); // TMA 加载完成计数
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); // 计算完成计数
}
cutlass::arch::fence_view_async_shared(); // 异步共享内存栅栏
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); // 多播栅栏
}
// 线程同步
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); // 集群或线程块同步
full_barriers 和 empty_barriers 管理数据流动。
cute::cluster_sync 用于多播时的集群同步。
代码片段展示了 fp8_gemm.cuh 如何通过线程分工、共享内存布局、流水线、TMA 加载、WGMMA 计算和同步机制实现高效 FP8 矩阵乘法。这些设计充分利用了 Hopper 架构的硬件特性
Gemm 类 封装内核,提供主机端接口,简化调用。
class Gemm {
private:
using Barrier = cuda::barrier<cuda::thread_scope_block>;
fp8_gemm.cuh 通过以下方式提升矩阵乘法效率:
硬件加速:利用 Hopper 张量核心和 TMA。
低精度优化:FP8 减少内存和计算开销,结合细粒度量化和双级累积保持精度。
流水线设计:多阶段重叠加载与计算。
贴一张官方的测试结果
利用官方给的python api进行测试
# 运行镜像
docker run -it --name conda3-centos7 \
--gpus all \
-v /opt/chenrui/DeepGEMM:/mnt \
-v /opt/chenrui/conda_envs:/usr/local/envs \
-v /usr/local/cuda-12.4:/usr/local/cuda-12.4 \
conda/miniconda3-centos7:latest
# 创建env
conda create --name py310 python=3.10
# 安装torch
pip install torch==2.3.1
# 安装 deepgemm
python setup.py develop
可以跑一下测试程序
# 测试 JIT 编译
python tests/test_jit.py# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
需要安装使用
python setup.py install
库还提供了一些环境变量,可能会很有用:
DG_CACHE_DIR
:字符串,存储编译内核的缓存目录,默认是$HOME/.deep_gemm
。DG_NVCC_COMPILER
:字符串,指定 NVCC 编译器的路径;默认在from torch.utils.cpp_extension.CUDA_HOME
中查找。DG_DISABLE_FFMA_INTERLEAVE
:0 或 1,禁用 FFMA 交错优化。DG_PTXAS_VERBOSE
:0 或 1,显示详细的 PTXAS 编译器输出。DG_PRINT_REG_REUSE
:0 或 1,打印 FFMA 交错的详细信息。DG_JIT_PRINT_NVCC_COMMAND
:0 或 1,打印 NVCC 编译命令。DG_JIT_DEBUG
:0 或 1,打印更多调试信息。