一、model.py中的RMSNorm源码
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
二、RMSNorm原理
归一化(Normalization)通常指的是将数据按比例缩放,使之落入一个小的特定区间,如0到1。这
个过程通常用于在不同特征或数据点之间建立一致性,以便它们可以在相同的尺度上比较或处理。在深度学习中,归一化有助于加快训练速度,提高模型性能,因为它确保了不同特征在训练过程中具
有相似的分布。RMSNorm的基本思想是对网络层的激活输出进行归一化,以使它们具有统一的规模
(scale),这样做可以加速训练过程并提高模型的稳定性。
R M S N o r m ( x ) = x 1 d ∑ i = 1 d x i 2 + ϵ RMSNorm(x) = \frac{x}{\sqrt{\frac{1}{d} \sum^{d}_{i=1} x_i^2 + \epsilon}} RMSNorm(x)=d1∑i=1dxi2+ϵx
- x x x 是网络层的原始输出向量
- d d d 是输出向量的维度
- x i x_i xi 是输出向量中的第 i i i 个元素
- ϵ \epsilon ϵ 是一个很小的常数,用来防止除以 0 0 0 ,通常是 1 0 − 7 10^{-7} 10−7 这样的小数,增加数值稳定性
在某些深层网络和序列模型中效果显著,但 R M S N o r m RMSNorm RMSNorm 可能不适用于任何类型的网络
三、源码注释
class RMSNorm(torch.nn.Module):
def __i
nit__(self, dim: int, eps: float = 1e-6):
# 初始化:dim——维度 eps——epsilon
super().__init__()
self.eps = eps
# weight,一个可以学习的权重,初始值1,维度与输出向量相同
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# rsqrt——开方分之一
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# type_as——确保归一化的结果和输入x有相同的数据类型
output = self._norm(x.float()).type_as(x)
# 将归一化的输出乘以权重参数,得到最终的输出
return output * self.weight
四、举例说明
构造输入x
x = [[1, 2], [5, 6]] x_tensor = torch.tensor(x, dtype = torch.float) x_tensor
tensor([[1., 2.],
[5., 6.]])
对输入数据求平方
x_square = x_tensor.pow(2) x_square
tensor([[ 1., 4.],
[25., 36.]])
沿着最后一个维度计算平均值
1 d ∑ i = 1 d x i 2 \frac{1}{d} \sum^{d}_{i=1} x_i^2 d1i=1∑dxi2
- d d d 是维度
对于每一个样本来说,先求出每一个特征的平方,在计算样本平方的均值
x_square_mean = x_square.mean(-1, keepdim=True) x_square_mean
tensor([[ 2.5000],
[30.5000]])
计算均方根的倒数
eps = 1e-6 rsqrt = 1.0 / torch.sqrt(x_square_mean + eps) rsqrt
tensor([[0.6325],
[0.1811]])
输入数据,得到归一化的结果
normalized_x = x_tensor * rsqrt normalized_x
tensor([[0.6325, 1.2649],
[0.9054, 1.0864]])
假设weight = [2, 3],那么最后的输出将是
weight = torch.tensor([2,3], dtype=torch.float) output = normalized_x * weight output
tensor([[1.2649, 3.7947],
[1.8107, 3.2593]])