layernorm backward: CUDA optimization analysis

Published: 2025-08-06

Overview

This article is aimed at readers who already have a CUDA background and need a quick layernorm backward implementation. For a detailed treatment of the layernorm backward math and the optimization details, please refer to the articles in the reference links; this post focuses on the code. Feedback and corrections are welcome, thank you!

The principles and optimization methods of layernorm_bwd have already been explained in detail by others (see the reference links), so they are not repeated here; this post only reproduces the commonly used layernorm_bwd optimizations in code.

1. layernorm_bwd algorithm and CPU implementation

  • layernorm_bwd formula derivation: (the original figures with the step-by-step derivation are not reproduced here; the resulting gradients are summarized below)
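For reference, the gradients that the CPU code below implements can be written as follows. This is a restatement matching that code (not a reproduction of the original figures), using H = hidden_dim, r = rstd, the normalized input x-hat, and g_i = weight_i * dout_i:

\hat{x}_i = (x_i - \mu)\, r, \qquad r = \frac{1}{\sqrt{\sigma^2 + \varepsilon}}

\frac{\partial L}{\partial \beta_i} = \sum_{b,t} \frac{\partial L}{\partial y_{b,t,i}}, \qquad \frac{\partial L}{\partial \gamma_i} = \sum_{b,t} \hat{x}_{b,t,i}\, \frac{\partial L}{\partial y_{b,t,i}}

\frac{\partial L}{\partial x_i} = r \Big( g_i - \frac{1}{H}\sum_{j=1}^{H} g_j - \hat{x}_i \, \frac{1}{H}\sum_{j=1}^{H} g_j\, \hat{x}_j \Big), \qquad g_i = \gamma_i \, \frac{\partial L}{\partial y_i}

In the code, the two means are dnorm_mean and dnorm_norm_mean, and dbias/dweight are accumulated over all (batch, seq_len) rows.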

template<typename T, typename T_ACC>
void layernorm_backward_cpu(T* dinput, T* dweight, T* dbias, const T* doutput,
                        const T* input, const T* weight, const T_ACC* mean, const T_ACC* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    for(int b=0; b<batch; b++){
        for(int i=0; i<seq_len; i++){
            const T* doutput_offset = doutput + b * seq_len * hidden_dim + i * hidden_dim;
            T* dinput_offset = dinput + b * seq_len * hidden_dim + i * hidden_dim;
            const T* input_offset = input + b * seq_len * hidden_dim + i * hidden_dim;
            const T_ACC mean_val = mean[b * seq_len + i];
            const T_ACC rstd_val = rstd[b * seq_len + i]; 

            T dnorm_mean = 0.0f;
            T dnorm_norm_mean = 0.0f;
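            // first pass over the row: accumulate sum(dnorm) and sum(dnorm * norm),
            // which become their means after dividing by hidden_dim below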
            for(int j = 0; j<hidden_dim; j++){
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
            dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

            for(int j = 0; j<hidden_dim; j++){
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];

                // gradient to bias
                dbias[j] += doutput_offset[j];

                // gradient to weight
                dweight[j] += norm_bti * doutput_offset[j];

                // gradient to input
                T dval = 0.0f;
                dval += dnorm_i;
                dval -= dnorm_mean;
                dval -= norm_bti * dnorm_norm_mean;
                dval *= rstd_val;
                dinput_offset[j] += dval;
            }
        }
    }
}

2. layernorm_bwd CUDA optimizations and implementations

2.1 layernorm_bwd_v1

  • Optimization: in v1, each thread handles one row, i.e. there are batch*seq_len threads in total, and each thread loops over the hidden_dim elements of its row.
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel1(T* dinput, T* dweight, T* dbias, const T* doutput,
                        const T* input, const T* weight, const T_ACC* mean, const T_ACC* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if(idx < batch * seq_len){
        const T* doutput_offset = doutput + idx * hidden_dim;
        T* dinput_offset = dinput + idx * hidden_dim;
        const T* input_offset = input + idx * hidden_dim;
        const T_ACC mean_val = mean[idx];
        const T_ACC rstd_val = rstd[idx]; 

        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for(int i=0; i<hidden_dim; i++){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }

        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        for(int i=0; i<hidden_dim; i++){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];

            // gradient to bias
            atomicAdd(&(dbias[i]), doutput_offset[i]);

            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);

            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(256, 1);
    dim3 grid((batch * seq_len + block.x - 1) / block.x, 1);
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel1<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu, 
                                input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);

2.2 layernorm_bwd_v2

  • Optimization: in v2, each warp handles one row, i.e. there are batch*seq_len warps in total, and each warp loops over the hidden_dim elements of its row; inside the warp, the two row-wise sums are reduced with warp shuffle instructions.
template <typename T>
__device__ T warpReduceSum(T val) {
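    // XOR-shuffle butterfly reduction: after log2(warpSize) steps every lane holds the full warp sum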
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel2(T* dinput, T* dweight, T* dbias, const T* doutput,
                        const T* input, const T* weight, const T_ACC* mean, const T_ACC* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    if(by < batch * seq_len){
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by]; 

        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for(int i=tx; i<hidden_dim; i+=blockDim.x){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
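        // the whole block is a single warp (launched as block(32, 1)), so a warp-level reduction combines the partial sums of the row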
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);

        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        for(int i=tx; i<hidden_dim; i+=blockDim.x){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];

            // gradient to bias
            atomicAdd(&(dbias[i]), doutput_offset[i]);

            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);

            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    layernorm_backward_kernel2<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu, 
                                input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);

2.3 layernorm_bwd_v3

  • Optimization: as in v2, 32 threads handle one row, but in this version each thread's doutput element is staged in shared memory to avoid reading it from global memory several times.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel3(T* dinput, T* dweight, T* dbias, const T* doutput,
                        const T* input, const T* weight, const T_ACC* mean, const T_ACC* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    if(by < batch * seq_len){
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by]; 

        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for(int i=tx; i<hidden_dim; i+=blockDim.x){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);

        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        for(int i=tx; i<hidden_dim; i+=blockDim.x){
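            // stage this thread's doutput element in shared memory once and reuse it for
            // dnorm_i, dbias and dweight below; the __syncthreads() assumes hidden_dim is
            // a multiple of blockDim.x so that no thread exits the loop early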
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];

            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);

            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);

            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}

    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    layernorm_backward_kernel3<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu, 
                                input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);

2.4 layernorm_bwd_v4

  • Optimization: building on v3, v4 uses 1024 threads to process one row cooperatively, with a block-level reduction (blockReduceSum) replacing the warp-level one.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T>
__device__ __inline__ T blockReduceSum(T val){
    __shared__ T shared[WARP_SIZE];
    __shared__ T ret;
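    // shared[] holds one partial sum per warp; ret broadcasts the final block-wide result to all threads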

    int warp_id = threadIdx.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;

    val = warpReduceSum(val);

    if(lane_id == 0){
        shared[warp_id] = val;
    }
    __syncthreads();

    // only the first num_warps slots of shared[] were written above; treat the rest as zero
    int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : (T)(0.0f);
    val = warpReduceSum(val);
    if (threadIdx.x == 0)
    {
        ret = val;
    }
    __syncthreads();

    return ret;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel4(T* dinput, T* dweight, T* dbias, const T* doutput,
                        const T* input, const T* weight, const T_ACC* mean, const T_ACC* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    if(by < batch * seq_len){
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by]; 

        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for(int i=tx; i<hidden_dim; i+=blockDim.x){
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = blockReduceSum<T>(dnorm_norm_mean);

        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        for(int i=tx; i<hidden_dim; i+=blockDim.x){
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];

            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);

            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);

            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(1024, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel4<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu, 
                                input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);

2.5 Other layernorm_bwd optimizations

The performance bottleneck of v4 is the atomicAdd on dbias and dweight: for every element of dbias and dweight, batch * seq_len threads serialize their accumulations on the same memory location, which is expensive. One remedy is to let each block(1024, 1) process several rows: each block first accumulates the contributions of its rows (smem[tx] for dbias and norm_bti * smem[tx] for dweight) in registers, and only afterwards performs atomicAdd with these per-block partial sums. This reduces the number of threads that perform atomicAdd on each location, and therefore the amount of serialized work, improving performance. A sketch of this idea is shown below.
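The following sketch is an illustration of this idea only (it is not code from the original article): each block processes ROWS_PER_BLOCK rows, every thread keeps register accumulators for the COLS_PER_THREAD columns it owns, and a single atomicAdd per column per block is issued at the end. It reuses the warpReduceSum/blockReduceSum helpers from v4, uses plain float instead of the T/T_ACC templates for brevity, and assumes hidden_dim == blockDim.x * COLS_PER_THREAD (for example 1024 threads and hidden_dim = 2048); ROWS_PER_BLOCK and COLS_PER_THREAD are illustrative tuning parameters, not names from the article.

// Illustrative sketch only: fewer atomics by accumulating several rows per block in registers.
// Assumes hidden_dim == blockDim.x * COLS_PER_THREAD and reuses blockReduceSum from v4.
#define ROWS_PER_BLOCK 8
#define COLS_PER_THREAD 2   // hidden_dim / blockDim.x, e.g. 2048 / 1024

__global__ void layernorm_backward_kernel5(float* dinput, float* dweight, float* dbias,
                        const float* doutput, const float* input, const float* weight,
                        const float* mean, const float* rstd,
                        const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int row_begin = blockIdx.x * ROWS_PER_BLOCK;
    int row_end = min(row_begin + ROWS_PER_BLOCK, batch * seq_len);

    // per-thread register accumulators for the columns this thread owns
    float dbias_acc[COLS_PER_THREAD] = {0.0f};
    float dweight_acc[COLS_PER_THREAD] = {0.0f};

    for(int row = row_begin; row < row_end; row++){
        const float* doutput_offset = doutput + row * hidden_dim;
        const float* input_offset = input + row * hidden_dim;
        float* dinput_offset = dinput + row * hidden_dim;
        const float mean_val = mean[row];
        const float rstd_val = rstd[row];

        // first pass: block-wide reductions over the row, as in v4
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for(int c = 0; c < COLS_PER_THREAD; c++){
            int i = tx + c * blockDim.x;
            float norm_bti = (input_offset[i] - mean_val) * rstd_val;
            float dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<float>(dnorm_mean) / hidden_dim;
        dnorm_norm_mean = blockReduceSum<float>(dnorm_norm_mean) / hidden_dim;

        // second pass: write dinput, but only accumulate dweight/dbias in registers
        for(int c = 0; c < COLS_PER_THREAD; c++){
            int i = tx + c * blockDim.x;
            float norm_bti = (input_offset[i] - mean_val) * rstd_val;
            float dnorm_i = weight[i] * doutput_offset[i];
            dbias_acc[c] += doutput_offset[i];
            dweight_acc[c] += norm_bti * doutput_offset[i];
            dinput_offset[i] += (dnorm_i - dnorm_mean - norm_bti * dnorm_norm_mean) * rstd_val;
        }
    }

    // one atomicAdd per column per block instead of one per row
    for(int c = 0; c < COLS_PER_THREAD; c++){
        int i = tx + c * blockDim.x;
        atomicAdd(&dbias[i], dbias_acc[c]);
        atomicAdd(&dweight[i], dweight_acc[c]);
    }
}

A matching launch would use one block per ROWS_PER_BLOCK rows, for example:

    dim3 block(1024, 1);
    dim3 grid((batch * seq_len + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK, 1);
    layernorm_backward_kernel5<<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);

A larger ROWS_PER_BLOCK means fewer atomics per dbias/dweight location but also fewer blocks, so it is a tuning knob rather than a fixed choice.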

3. Performance comparison of the layernorm_bwd versions

Data type and problem size: FP32, 16 x 64 x 2048 (batch x seq_len x hidden_dim)
Hardware platform: A100-SXM

version             cycles
layernorm_bwd_v1    7482424
layernorm_bwd_v2     251740
layernorm_bwd_v3     253976
layernorm_bwd_v4      98369

References

No.  Link                                      Notes
1    https://zhuanlan.zhihu.com/p/694974164    layernorm CUDA implementation
2    https://www.jianshu.com/p/db89d62e1974    layernorm backward derivation
