大模型学习笔记------Llama 3模型架构之RMS Norm与激活函数SwiGLU
上文简单介绍了 Llama 3模型架构。在以后的文章中将逐步学习并记录Llama 3模型中的各个部分。本文将首先介绍归一化模块RMS Norm与激活函数SwiGLU。
1、归一化模块RMS Norm
归一化模块是各个网络结构中必有得模块之一。Llama 3模型基于Transformer,Transformer中采用的归一化模块通常为层归一化Layer Norm(LN),如下图所示。而Llama模型采用LN的改进版RMS Norm。
其中,N为batch size的大小,C为特征图通道数,H为特征图高,W为特征图宽。
LN层的计算为:
实质上,LN是针对layer维度进行标准化,在C,H,W上进行归一化,也就是与batch无关,执行完有B个均值,B个方差。每个样本共用相同的均值和方差。
RMS Norm是一种简化版的层归一化,它是通过均方根(Root Mean Square)计算得到归一化尺度,不需要计算均值和标准差。具体如下:
与层归一化相比,RMSNorm不需要计算均值和标准差,减少了计算复杂度,提高了计算的效率,同时保持了良好的归一化效果。
2、激活函数SwiGLU
Llama 3采用的是SwiGLU(Swish-Gated Linear Unit)激活函数,SwiGLU激活函数是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。
2.1 常用激活函数
其实,Swish激活函数与GLU激活函数都不是最常见的激活函数,在transformer结构中常用的激活函数是ReLU(Rectified Linear Unit)、GELU(Gaussian Error Linear Unit)等。在GPT3中采用的就是GELU激活函数,他们的形式如下图所示:
2.2 Swish激活函数
Swish激活函数是一种平滑且连续的激活函数,在Transformer等模型中应用广泛。具体计算公式如下所示:
Swish(x) = x * sigmoid(βx)
其中 sigmoid(x) = 1 / (1 + exp(-x)), β 是一个可学习的参数或一个常数。
通常 β 设置为 1,即 Swish(x) = x * sigmoid(x),此时的激活函数通常称为SiLU,有没有熟悉的感觉。当 β = 0 时,Swish 变为线性函数 x/2。 当 β 趋近于无穷大时,Swish 接近 ReLU。 因此,Swish 可以看作是 ReLU 和线性函数之间的插值。Swish与RuLU的对比如下所示:
2.3 GLU(Gated Linear Unit)激活函数
GLU是一种强大的激活函数,它通过引入门控机制来动态地控制信息流。 它在许多任务中都表现出色,尤其是在需要处理复杂依赖关系的场景中。具体计算函数如下:
GLU(x) = (W_a * x + a) ⊗ σ(W_b * x +b)
其中,W_a, W_b 是权重矩阵; a, b 是偏置向量;σ 是 sigmoid 函数。在 GLU 中,输入特征被分成两个部分:
- 线性部分: (W_a * x + a)
- 门控部分:σ(W_b * x +b)
具体表现形式如下所示:
2.4 SwiGLU激活函数
SwiGLU(Swish-Gated Linear Unit)是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。具体计算如下所示:
SwiGLU是对输入进行两次线性变换,将其中一个结果通过Swish激活函数,将两个结果逐元素相乘,形成最终输出。
SwiGLU在处理复杂的文本数据时表现出更好的泛化能力。由于引入了Swish激活函数,其平滑的曲线特性有助于稳定梯度,减少梯度消失和爆炸的可能性。
3、一些思考
1、RMS Norm在模型中的作用仅仅是减少参数量提高效率这么简单吗?
答案当然是否定的,以下是我的一些思考,仅提供参考:
- 中心化不是必须的:LN 通过减去均值来中心化数据,使得数据分布以 0 为中心。 然而,对于某些任务,这种中心化可能并不是必需的。 神经网络可以通过后续的权重和偏置来学习合适的偏移。
- 更关注方差/幅度:RMSNorm 本质上是对输入的幅度进行归一化。 在很多情况下,幅度信息可能比中心化信息更重要。 例如,在语音识别中,信号的能量 (幅度) 通常是重要的特征。
- 与优化器的结合:RMSNorm 的设计理念与一些优化器 (如 Adam) 相似,它们都关注梯度的幅度。 这种一致性可能有助于提高训练的稳定性。
2、为什么采用Swish与GLU结合,而不是使用GELU与GLU结合?
从我的认知中,可能主要是基于以下两个方面的考虑:
- 效率考量:虽然从函数的表现形式上来看,Swish与GELU基本差不多。但是GELU采用的主要使用高斯函数,当然,可以使用tanh和sqrt等进行相似计算。而Swish采用Sigmoid函数。从计算角度看,Swish的计算效率更高。
- 实验效果:感觉当初开发Llama 3模型的时候,开发者应该考虑过GELU激活函数,但是最终选择 Swish + GLU 可能是因为在 Llama 3 的特定架构、数据集和训练设置下,这种组合取得了更好的实验结果。 深度学习模型的选择很多时候是经验性的,需要大量的实验来验证。