简述
本文面向拥有CUDA知识背景并有快速实现layernorm backward需求的读者,若想详细了解layernorm backward计算原理、优化细节请移步参考链接中的文章,本文更侧重于代码实现。如有高见请不吝赐教,谢谢!
很多大佬已经对layernorm_bwd原理、优化方法有过详细讲解(参考链接),这里不再赘述,只是对layernorm_bwd常用优化方法代码复现。
1. layernorm_bwd算法原理及cpu实现
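- layernorm_bwd公式推导(记号与下方代码一致):对每一行,记 $N$ 为 hidden_dim,$\mu$ 为前向保存的均值(mean),$\mathrm{rstd}$ 为前向保存的标准差倒数,$\hat{x}_j=(x_j-\mu)\cdot\mathrm{rstd}$,$g_j=\gamma_j\,\mathrm{d}y_j$($\gamma$ 即 weight,$\mathrm{d}y$ 即 doutput),则:
$$\frac{\partial L}{\partial \beta_j}=\sum_{\text{rows}} \mathrm{d}y_j,\qquad \frac{\partial L}{\partial \gamma_j}=\sum_{\text{rows}} \mathrm{d}y_j\,\hat{x}_j$$
$$\frac{\partial L}{\partial x_j}=\mathrm{rstd}\left(g_j-\frac{1}{N}\sum_{k=1}^{N} g_k-\hat{x}_j\cdot\frac{1}{N}\sum_{k=1}^{N} g_k\,\hat{x}_k\right)$$
代码中 dnorm_mean 对应 $\frac{1}{N}\sum_k g_k$,dnorm_norm_mean 对应 $\frac{1}{N}\sum_k g_k\hat{x}_k$。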
template<typename T, typename T_ACC>
void layernorm_backward_cpu(T* dinput, T* dweight, T* dbias, T* doutput,
                            T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                            const int batch, const int seq_len, const int hidden_dim)
{
    // Host reference implementation of LayerNorm backward.
    //
    // Layout: row r = b * seq_len + i of input / doutput / dinput starts at
    // element r * hidden_dim; mean[r] and rstd[r] are that row's statistics.
    // With xhat_j = (x_j - mean) * rstd and g_j = weight_j * dy_j:
    //   dbias_j   += dy_j
    //   dweight_j += xhat_j * dy_j
    //   dx_j      += rstd * (g_j - avg_k(g_k) - xhat_j * avg_k(g_k * xhat_k))
    // All three outputs are accumulated with +=, so the caller must zero them
    // beforehand if a fresh gradient is wanted.
    const int num_rows = batch * seq_len;
    for (int row = 0; row < num_rows; ++row) {
        const T* dy = doutput + row * hidden_dim;
        const T* x = input + row * hidden_dim;
        T* dx = dinput + row * hidden_dim;
        const T mu = static_cast<T>(mean[row]);
        const T_ACC inv_std = rstd[row];

        // Normalized activation of column j, recomputed on demand instead of stored.
        auto xhat = [&](int j) -> T {
            return (x[j] - mu) * static_cast<T>(inv_std);
        };

        // Pass 1: row averages of g and g * xhat.
        T sum_g = 0.0f;
        T sum_g_xhat = 0.0f;
        for (int j = 0; j < hidden_dim; ++j) {
            const T xh = xhat(j);
            const T g = weight[j] * dy[j];
            sum_g += g;
            sum_g_xhat += g * xh;
        }
        const T mean_g = sum_g / static_cast<T>(hidden_dim);
        const T mean_g_xhat = sum_g_xhat / static_cast<T>(hidden_dim);

        // Pass 2: accumulate parameter gradients and the input gradient.
        for (int j = 0; j < hidden_dim; ++j) {
            const T xh = xhat(j);
            const T g = weight[j] * dy[j];
            dbias[j] += dy[j];
            dweight[j] += xh * dy[j];
            T grad = g - mean_g - xh * mean_g_xhat;
            grad *= inv_std;
            dx[j] += grad;
        }
    }
}
2. layernorm_bwd cuda优化方法及实现
2.1 layernorm_bwd_v1
- 优化方法:v1版本是每个线程计算一行数据,即一共有batch*seq_len个线程,每个线程循环计算hidden_dim个数据;
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel1(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    // LayerNorm backward, v1: each thread owns one whole row and walks all
    // hidden_dim columns sequentially. The global thread index is the row;
    // any 1-D launch covering batch * seq_len threads works.
    // Neighbouring threads touch addresses hidden_dim elements apart, so the
    // loads below are not coalesced.
    // dinput is accumulated (+=). dweight/dbias are accumulated with atomicAdd
    // across all rows, so the caller must initialise all three.
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch * seq_len) {
        return;  // surplus threads in the last block have no row
    }

    const T* dy = doutput + row * hidden_dim;
    const T* x = input + row * hidden_dim;
    T* dx = dinput + row * hidden_dim;
    const T mu = static_cast<T>(mean[row]);
    const T_ACC inv_std = rstd[row];

    // Pass 1: row averages of g = weight * dy and of g * xhat.
    T sum_g = 0.0f;
    T sum_g_xhat = 0.0f;
    for (int j = 0; j < hidden_dim; ++j) {
        const T xhat = (x[j] - mu) * static_cast<T>(inv_std);
        const T g = weight[j] * dy[j];
        sum_g += g;
        sum_g_xhat += g * xhat;
    }
    const T mean_g = sum_g / static_cast<T>(hidden_dim);
    const T mean_g_xhat = sum_g_xhat / static_cast<T>(hidden_dim);

    // Pass 2: parameter gradients (shared by all rows -> atomics) and dx.
    for (int j = 0; j < hidden_dim; ++j) {
        const T xhat = (x[j] - mu) * static_cast<T>(inv_std);
        const T g = weight[j] * dy[j];
        atomicAdd(&dbias[j], dy[j]);
        atomicAdd(&dweight[j], xhat * dy[j]);
        T grad = g - mean_g - xhat * mean_g_xhat;
        grad *= inv_std;
        dx[j] += grad;
    }
}
// v1 launch: one thread per row, 256 threads per block.
dim3 block(256, 1);
// Round the block count up so that a row count that is not a multiple of
// block.x still gets one thread per row; the kernel's
// `idx < batch * seq_len` guard discards the surplus threads. Plain integer
// division here would silently skip the last (batch * seq_len) % 256 rows.
dim3 grid((batch * seq_len + block.x - 1) / block.x, 1);
util::print_cuda_cfg(grid, block);
layernorm_backward_kernel1<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);
2.2 layernorm_bwd_v2
- 优化方法:v2版本是每个warp计算一行数据,即一共有batch*seq_len个warp(每个block只有32个线程);warp内的线程以blockDim.x为步长循环处理hidden_dim个数据,然后通过线程束洗牌指令(__shfl_xor_sync)在warp内求和,得到dnorm_mean和dnorm_norm_mean。由于只做warp内归约,block大小必须为32。
template <typename T>
__device__ T warpReduceSum(T val) {
    // Sum `val` over all lanes of the calling warp. The xor (butterfly)
    // exchange pattern leaves the complete sum in every lane, not only lane 0.
    // All lanes of the warp must call this together (full 0xFFFFFFFF mask).
    #pragma unroll
    for (int lane_stride = warpSize >> 1; lane_stride != 0; lane_stride >>= 1) {
        const T partner = __shfl_xor_sync(0xFFFFFFFF, val, lane_stride);
        val = val + partner;
    }
    return val;
}
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel2(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    // LayerNorm backward, v2: one warp per row. blockIdx.y selects the row and
    // the block's threads stride over hidden_dim in steps of blockDim.x.
    // The row sums are combined with a single warpReduceSum, so the kernel is
    // only correct when blockDim.x == 32 (one full warp per block).
    // dinput is accumulated (+=). dweight/dbias are accumulated with atomicAdd
    // across all rows, so the caller must initialise all three.
    const int row = blockIdx.y;
    if (row >= batch * seq_len) {
        return;  // whole block exits together: row is uniform per block
    }

    const int lane = threadIdx.x;
    const T* dy = doutput + row * hidden_dim;
    const T* x = input + row * hidden_dim;
    T* dx = dinput + row * hidden_dim;
    const T mu = static_cast<T>(mean[row]);
    const T_ACC inv_std = rstd[row];

    // Pass 1: per-lane partial sums, then warp-wide totals in every lane.
    T sum_g = 0.0f;
    T sum_g_xhat = 0.0f;
    for (int j = lane; j < hidden_dim; j += blockDim.x) {
        const T xhat = (x[j] - mu) * static_cast<T>(inv_std);
        const T g = weight[j] * dy[j];
        sum_g += g;
        sum_g_xhat += g * xhat;
    }
    sum_g = warpReduceSum<T>(sum_g);
    sum_g_xhat = warpReduceSum<T>(sum_g_xhat);
    const T mean_g = sum_g / static_cast<T>(hidden_dim);
    const T mean_g_xhat = sum_g_xhat / static_cast<T>(hidden_dim);

    // Pass 2: parameter gradients (shared by all rows -> atomics) and dx.
    for (int j = lane; j < hidden_dim; j += blockDim.x) {
        const T xhat = (x[j] - mu) * static_cast<T>(inv_std);
        const T g = weight[j] * dy[j];
        atomicAdd(&dbias[j], dy[j]);
        atomicAdd(&dweight[j], xhat * dy[j]);
        T grad = g - mean_g - xhat * mean_g_xhat;
        grad *= inv_std;
        dx[j] += grad;
    }
}
// v2 launch: one 32-thread block (exactly one warp) per row, rows laid out
// along gridDim.y. The kernel's warp-only reduction requires blockDim.x == 32.
// NOTE(review): gridDim.y is capped at 65535 by the hardware, so this launch
// fails once batch * seq_len exceeds that.
dim3 block(32);
dim3 grid(1, batch * seq_len);
layernorm_backward_kernel2<T, T_ACC><<<grid, block>>>(
    dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
    input_gpu, weight_gpu, mean_gpu, rstd_gpu,
    batch, seq_len, hidden_dim);
2.3 layernorm_bwd_v3
- 优化方法:基于v2版本,仍采用32个线程(一个warp)计算一行数据;不同的是在第二个循环中先将doutput_offset[i]读入smem,随后计算dnorm_i、dbias和dweight时复用smem中的值,避免对同一global memory地址多次访问。
template <typename T>
__device__ T warpReduceSum(T val) {
    // Warp-wide sum via xor shuffles; afterwards every lane holds the total.
    // Must be called by all lanes of a fully active warp (mask 0xFFFFFFFF).
    #pragma unroll
    for (int hop = warpSize / 2; hop >= 1; hop /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, hop);
    }
    return val;
}
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel3(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    // LayerNorm backward, v3: one warp per row as in v2. In the second pass,
    // each doutput element is read from global memory once into the
    // per-thread shared-memory slot smem[tx] and reused from there.
    //
    // Launch: gridDim = (1, batch * seq_len) with blockIdx.y as the row.
    // blockDim.x must be 32, because only a single warpReduceSum is done.
    // Dynamic shared memory: blockDim.x * sizeof(T) bytes.
    // dinput is accumulated (+=). dweight/dbias are accumulated with atomicAdd
    // across all rows, so the caller must initialise all three.
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];

        // Pass 1: partial sums of g = weight * dy and g * xhat, then warp totals.
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        // Pass 2: parameter gradients (atomics, shared by all rows) and dx.
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            // smem[tx] is written and read back only by this thread, so no
            // barrier is needed. A __syncthreads() must NOT be placed here: when
            // hidden_dim is not a multiple of blockDim.x, threads leave this loop
            // after different trip counts, and a barrier reached by only part of
            // the block is undefined behaviour (typically a hang).
            smem[tx] = doutput_offset[i];
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
// v3 launch: one 32-thread block (one warp) per row along gridDim.y, plus
// one T of dynamic shared memory per thread for the doutput staging slot.
// NOTE(review): gridDim.y is capped at 65535, so batch * seq_len must not exceed that.
dim3 block(32);
dim3 grid(1, batch * seq_len);
size_t smem_size = block.x * sizeof(T);
layernorm_backward_kernel3<T, T_ACC><<<grid, block, smem_size>>>(
    dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
    input_gpu, weight_gpu, mean_gpu, rstd_gpu,
    batch, seq_len, hidden_dim);
2.4 layernorm_bwd_v4
- 优化方法:基于v3版本,v4版本让一个block中的1024个线程循环计算一行,并用blockReduceSum(先warp内归约,再对各warp的部分和归约)得到整行的求和结果。
template <typename T>
__device__ T warpReduceSum(T val) {
    // Butterfly sum across one warp: each step adds the value held by the lane
    // whose id differs in one bit, so all lanes finish with the warp total.
    // WARP_SIZE is defined outside this snippet (presumably 32 — confirm).
    // Requires every lane of the warp to participate (mask 0xFFFFFFFF).
    #pragma unroll
    for (int delta = WARP_SIZE >> 1; delta > 0; delta /= 2) {
        T other = __shfl_xor_sync(0xFFFFFFFF, val, delta);
        val += other;
    }
    return val;
}
template<typename T>
__device__ __inline__ T blockReduceSum(T val){
    // Sum `val` over every thread of the block and return the total to all
    // threads.
    // Step 1: each warp reduces its own values and lane 0 stores the partial
    //         sum in shared[warp_id].
    // Step 2: the first warp reduces those partial sums.
    // Step 3: thread 0 publishes the result through `ret`.
    //
    // Preconditions: every thread of the block calls this (it contains
    // __syncthreads()), and blockDim.x is a multiple of WARP_SIZE, because
    // warpReduceSum shuffles with a full-warp mask. blockDim.x is at most
    // WARP_SIZE * WARP_SIZE, so `shared` has one slot per warp.
    __shared__ T shared[WARP_SIZE];
    __shared__ T ret;
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const unsigned int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

    val = warpReduceSum(val);
    if (lane_id == 0) {
        shared[warp_id] = val;
    }
    __syncthreads();

    // Only shared[0 .. num_warps-1] were written above. Slots beyond that are
    // uninitialised whenever the block has fewer than WARP_SIZE warps, so they
    // must be replaced by zero rather than read.
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : (T)(0.0f);
    val = warpReduceSum(val);
    if (threadIdx.x == 0) {
        ret = val;
    }
    __syncthreads();
    return ret;
}
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel4(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    // LayerNorm backward, v4: one block (up to 1024 threads) per row. The row
    // sums are combined across warps with blockReduceSum.
    //
    // Launch: gridDim = (1, batch * seq_len) with blockIdx.y as the row.
    // blockDim.x must be a multiple of 32, because blockReduceSum shuffles with
    // a full warp mask. Dynamic shared memory: blockDim.x * sizeof(T) bytes.
    // dinput is accumulated (+=). dweight/dbias are accumulated with atomicAdd
    // across all rows, so the caller must initialise all three.
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    // `by` is uniform across the block, so all threads take the same branch
    // and the barriers inside blockReduceSum are reached by the whole block.
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];

        // Pass 1: partial sums of g = weight * dy and g * xhat, then block totals.
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = blockReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);

        // Pass 2: parameter gradients (atomics, shared by all rows) and dx.
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            // smem[tx] is written and read back only by this thread, so no
            // barrier is needed. A __syncthreads() must NOT be placed here: if
            // hidden_dim is not a multiple of blockDim.x (e.g. 768 with 1024
            // threads), some threads skip or leave this loop early, and a
            // barrier reached by only part of the block is undefined behaviour.
            smem[tx] = doutput_offset[i];
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
// v4 launch: one 1024-thread block (32 warps) per row along gridDim.y, plus
// one T of dynamic shared memory per thread for the doutput staging slot.
// NOTE(review): gridDim.y is capped at 65535, so batch * seq_len must not exceed that.
dim3 block(1024);
dim3 grid(1, batch * seq_len);
size_t smem_size = block.x * sizeof(T);
util::print_cuda_cfg(grid, block);
layernorm_backward_kernel4<T, T_ACC><<<grid, block, smem_size>>>(
    dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
    input_gpu, weight_gpu, mean_gpu, rstd_gpu,
    batch, seq_len, hidden_dim);
2.5 layernorm_bwd其他优化方法
v4版本的性能瓶颈在于对dbias和dweight的atomicAdd:dbias和dweight的每一个内存位置都要被batch * seq_len个线程累加,这些对同一地址的原子操作只能串行执行,较为耗时。因此可以让一个block(1024, 1)负责计算多行:先把该block所负责各行的smem[tx](即doutput)以及norm_bti × smem[tx]分别累加到寄存器中,最后每个block只对dbias和dweight各执行一次atomicAdd。这样需要执行atomicAdd的线程数量按每个block负责的行数成比例减少,串行累加随之减少,从而提升性能。
3. layernorm_bwd 不同版本性能对比
数据类型及规模:FP32,batch=16,seq_len=64,hidden_dim=2048
硬件平台:A100-SXM
layernorm_bwd version | cycle |
---|---|
layernorm_bwd_v1 | 7482424 |
layernorm_bwd_v2 | 251740 |
layernorm_bwd_v3 | 253976 |
layernorm_bwd_v4 | 98369 |
参考链接
序号 | 链接 | 备注 |
---|---|---|
1 | https://zhuanlan.zhihu.com/p/694974164 | layernorm cuda 代码实现 |
2 | https://www.jianshu.com/p/db89d62e1974 | layernorm 反向推导公式 |