PyTorch中Batch Normalization1d的实现与手动验证
一、介绍
Batch Normalization(批归一化)是深度学习中常用的技术,用于加速训练并减少对初始化的敏感性。本文将通过PyTorch内置函数和手动实现两种方式,展示如何对三维输入张量(batch_size, seq_len, embedding_dim)进行批归一化,并验证两者的等价性。 想节省时间的读者直接看下图, 以自然语言处理任务为例。假设输入的维度是(bs, seq_len, embedding)。那么pytorch中的batchnorm1d会对淡蓝色的矩阵做归一化,最后得到embedding长度的均值的方差。接下来进行编程验证。
二、PyTorch内置实现
1. 输入维度调整
PyTorch的nn.BatchNorm1d
要求输入维度为 (batch_size, num_features, ...)
,因此需要将原始输入的维度 (batch_size, seq_len, embedding_dim)
转置为 (batch_size, embedding_dim, seq_len)
。
import torch
import torch.nn as nn
batch_size = 8
seq_len = 10
embedding_dim = 32
# 创建输入张量
x = torch.randn(batch_size, seq_len, embedding_dim)
# 转置输入以适应BatchNorm1d
x_pytorch = x.transpose(1, 2) # shape变为 (batch_size, embedding_dim, seq_len)
2. 使用nn.BatchNorm1d
初始化BatchNorm1d
层,其参数num_features
设置为embedding_dim
。
bn_pytorch = nn.BatchNorm1d(embedding_dim)
# 前向传播
out_pytorch = bn_pytorch(x_pytorch)
out_pytorch = out_pytorch.transpose(1, 2) # 转换回原始维度
三、手动实现Batch Normalization
1. 计算均值和方差
手动实现需沿着batch
和seq_len
维度(前两个维度)计算均值和方差:
def manual_batchnorm(x, gamma, beta, eps=1e-5):
# 计算均值和方差(沿着batch和seq_len维度)
mean = torch.mean(x, dim=(0, 1), keepdim=True)
var = torch.var(x, dim=(0, 1), keepdim=True, unbiased=False) # 使用分母为n
# 标准化
x_normalized = (x - mean) / torch.sqrt(var + eps)
# 应用缩放和平移参数
return gamma * x_normalized + beta
2. 获取PyTorch的参数
为确保与PyTorch实现一致,需获取其gamma
(缩放参数)和beta
(偏移参数),并调整形状:
gamma = bn_pytorch.weight.view(1, 1, embedding_dim) # 形状为 (1, 1, embedding_dim)
beta = bn_pytorch.bias.view(1, 1, embedding_dim)
3. 手动前向传播
直接使用原始输入张量:
out_manual = manual_batchnorm(x, gamma, beta)
四、验证一致性
通过比较PyTorch和手动实现的输出结果,验证两者是否等价:
print("是否相同:", torch.allclose(out_pytorch, out_manual))
输出结果:
是否相同: True
五、关键点解析
维度调整:
- PyTorch的
BatchNorm1d
要求特征维度在第二位,因此需转置输入。 - 手动实现无需转置,直接沿前两个维度计算。
- PyTorch的
方差计算:
- PyTorch的
var
默认使用无偏估计(分母为n-1
),但BatchNorm1d
使用分母为n
,因此需设置unbiased=False
。
- PyTorch的
参数一致性:
gamma
和beta
需与PyTorch层的参数一致,通过调整形状确保广播正确。
六、完整代码
import torch
import torch.nn as nn
torch.manual_seed(42)
batch_size = 8
seq_len = 10
embedding_dim = 32
# 创建输入张量
x = torch.randn(batch_size, seq_len, embedding_dim)
# 使用PyTorch的BatchNorm1d
bn_pytorch = nn.BatchNorm1d(embedding_dim)
x_pytorch = x.transpose(1, 2) # 转置为 (batch_size, embedding_dim, seq_len)
out_pytorch = bn_pytorch(x_pytorch)
out_pytorch = out_pytorch.transpose(1, 2) # 转换回原始维度
# 手动实现
def manual_batchnorm(x, gamma, beta, eps=1e-5):
mean = torch.mean(x, dim=(0,1), keepdim=True)
var = torch.var(x, dim=(0,1), keepdim=True, unbiased=False)
x_normalized = (x - mean) / torch.sqrt(var + eps)
return gamma * x_normalized + beta
# 获取PyTorch的参数
gamma = bn_pytorch.weight.view(1,1,embedding_dim)
beta = bn_pytorch.bias.view(1,1,embedding_dim)
out_manual = manual_batchnorm(x, gamma, beta)
# 验证结果
print("是否相同:", torch.allclose(out_pytorch, out_manual))
七、总结
通过PyTorch内置函数和手动实现的对比,我们验证了两者在批归一化计算上的等价性。关键点在于维度调整、方差计算方式以及参数的正确应用。这种验证方法有助于理解批归一化的内部机制,同时确保手动实现的正确性。@TOC