进阶的 LLM Llama模型教学
Llama 模型在自然语言处理领域有着广泛的应用,它通过自注意力机制能够有效地捕捉序列中的长距离依赖关系。为了更好地理解和实现这个模型,我们先从一些基础的代码和概念入手。
一、库导入
在开始之前,我们需要导入一些必要的 Python 库。这些库将帮助我们完成模型的构建和训练。
import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
这些库涵盖了数学运算、数据结构定义、类型提示以及 PyTorch 框架的相关功能,为后续的模型实现提供了强大的支持。
二、实现 ModelArgs 参数类构建
接下来,我们来定义一个参数类 ModelArgs
,用于存储 Transformer 模型的各种超参数。这些超参数将决定模型的结构和行为。
@dataclass
class ModelArgs:
# 自定义超参数
dim: int = 288 # 模型维度
n_layers: int = 6 # Transformer层数
n_heads: int = 6 # 注意力机制的头数
n_kv_heads: Optional[int] = 6 # 键/值头数,如果未指定,则默认为n_heads
vocab_size: int = 32000 # 词汇表大小
hidden_dim: Optional[int] = None # 隐藏层维度,如果未指定,则使用其他规则确定
multiple_of: int = 32 # MLP隐藏层大小是这个数的倍数
norm_eps: float = 1e-5 # 归一化层的epsilon值
max_seq_len: int = 256 # 最大序列长度
dropout: float = 0.0 # 丢弃率
Transformer 模型参数解释
dim
(模型的嵌入维度):这是每个输入词或序列元素的特征维度。它决定了模型对输入数据的表示能力。n_heads
(多头注意力机制中的头数):这个参数决定了嵌入维度如何被拆分以及进行并行计算。多头注意力机制能够让模型从不同的角度学习输入数据的特征。n_layers
(Transformer 的层数):即模型中包含的 Transformer 编码器或解码器的数量。层数越多,模型能够捕捉到的复杂关系就越多,但计算成本也会相应增加。n_kv_heads
(键(Key)和值(Value)的头数):在某些模型(如 LLaMA)中,键和值的头数可以与查询(Query)的头数不同,以减少计算量。这个参数提供了灵活性,使模型能够在保持性能的同时降低计算成本。vocab_size
(词汇表的大小):即模型可以处理的不同词或标记的数量。它决定了模型的输入范围。hidden_dim
(MLP 隐藏层的维度):这是多层感知机(MLP)隐藏层的维度。如果未指定,则会根据其他规则(如模型维度的倍数)动态计算。MLP 是 Transformer 中的一个重要组件,用于对输入数据进行非线性变换。multiple_of
(MLP 隐藏层大小的倍数):MLP 隐藏层大小必须是这个数的倍数。这通常是出于硬件优化的考虑,例如在 GPU 上进行矩阵运算时,某些维度大小为 32 的倍数可以提高计算效率。norm_eps
(归一化层的 epsilon 值):这是归一化层(如 LayerNorm)中的一个小常数,用于防止除零操作。在计算归一化时,它能够确保数值稳定性。max_seq_len
(最大序列长度):即输入序列的最大长度。这个参数限制了模型能够处理的序列长度,对于长文本处理非常重要。dropout
(丢弃率):这是在训练过程中,模型中某些层的输出被随机丢弃的比例。丢弃率可以防止过拟合,并提高模型的泛化能力。
三、实现均方根归一化(RMSNorm,LayerNorm 的一种变体)层
定义与原理
RMSNorm 是 LayerNorm 的一种变体,它通过计算输入向量的均方根(Root Mean Square, RMS)来进行归一化,而省略了计算均值的步骤。这种方法在某些情况下能够提高计算效率和数值稳定性。
RMSNorm 公式
对于输入向量 x = [ x 1 , x 2 , … , x H ] \mathbf{x} = [x_1, x_2, \dots, x_H] x=[x1,x2,…,xH],RMSNorm 的计算步骤如下:
- 计算均方根值(RMS):
RMS ( x ) = 1 H ∑ i = 1 H x i 2 \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2} RMS(x)=H1i=1∑Hxi2
其中, H H H是输入向量的维度。
2. 归一化:
x ^ i = x i RMS ( x ) + ϵ \hat{x}_i = \frac{x_i}{\text{RMS}(\mathbf{x}) + \epsilon} x^i=RMS(x)+ϵxi
其中, ϵ \epsilon ϵ是一个极小的常数(如 1 0 − 8 10^{-8} 10−8),用于防止分母为零。
3. 缩放(可选):
y i = γ i ⋅ x ^ i y_i = \gamma_i \cdot \hat{x}_i yi=γi⋅x^i
其中, γ \gamma γ是可学习的缩放参数,与输入向量同维度。
将上述步骤综合起来,RMSNorm 的完整公式为:
RMSNorm ( x ) = γ ⊙ x 1 H ∑ i = 1 H x i 2 + ϵ \text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2 + \epsilon}} RMSNorm(x)=γ⊙H1∑i=1Hxi2+ϵx
其中, ⊙ \odot ⊙表示逐元素乘法。
与 LayerNorm 的对比
为了更好地理解 RMSNorm,我们来看一下它与 LayerNorm 的区别:
- LayerNorm 公式:
x ^ i = x i − μ σ 2 + ϵ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} x^i=σ2+ϵxi−μ
其中, μ \mu μ是均值, σ 2 \sigma^2 σ2是方差。
- RMSNorm 公式:
x ^ i = x i 1 H ∑ i = 1 H x i 2 + ϵ \hat{x}_i = \frac{x_i}{\sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2 + \epsilon}} x^i=H1∑i=1Hxi2+ϵxi
RMSNorm 省略了均值的计算,仅使用均方根值进行归一化。
RMSNorm 的优点
RMSNorm 与 LayerNorm 相比,具有以下优势:
- 计算效率更高:RMSNorm 省略了计算均值的步骤,仅需计算平方均值,减少了约 15% 的计算量。
- 数值稳定性更好:由于不涉及均值计算,RMSNorm 在某些情况下可以避免均值归一化导致的梯度消失问题。
- 适用于 Transformer 架构:在 Transformer 等对计算效率敏感的场景中,RMSNorm 可以显著加速训练。
RMSNorm 的实现
接下来,我们来看看如何用 Python 实现 RMSNorm。
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
# eps 是为了防止除以 0 的情况
self.eps = eps
# weight 是一个可学习的参数,全部初始化为 1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算 RMSNorm 的核心部分
# x.pow(2).mean(-1, keepdim=True) 计算了输入 x 的平方的均值
# torch.rsqrt 是平方根的倒数,这样就得到了 RMSNorm 的分母部分,再加上 eps 防止分母为 0
# 最后乘以 x,得到 RMSNorm 的结果
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# forward 函数是模型的前向传播
# 首先将输入 x 转为 float 类型,然后进行 RMSNorm,最后再转回原来的数据类型
# 最后乘以 weight,这是 RMSNorm 的一个可学习的缩放因子
output = self._norm(x.float()).type_as(x)
return output * self.weight
RMSNorm 的关键步骤
参数说明
dim
:输入数据的特征维度。例如,如果输入数据的形状是(batch_size, sequence_length, dim)
,则dim
是最后一个维度的大小。eps
:一个非常小的数值,用于防止分母为零,确保数值稳定性。
__init__
关键操作
self.eps
:存储eps
值,用于后续的归一化计算。self.weight
:定义一个可学习的参数weight
,其初始值为全1。这个参数在归一化后对输出进行缩放。
_norm
关键操作
计算平方的均值:
x.pow(2).mean(-1, keepdim=True)
x.pow(2)
:计算输入张量x
的每个元素的平方。.mean(-1, keepdim=True)
:沿着最后一个维度(特征维度)计算均值,并保持输出的维度与输入相同。
计算平方根的倒数:
torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
torch.rsqrt
:计算平方根的倒数,即 1 value \frac{1}{\sqrt{\text{value}}} value1。+ self.eps
:在分母中添加eps
,防止分母为零。
归一化:
x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
- 将输入
x
乘以平方根的倒数,完成归一化。
- 将输入
forward
关键操作
类型转换:
x.float()
- 将输入
x
转换为float
类型,以确保计算的数值稳定性。
- 将输入
调用
_norm
方法:self._norm(x.float())
- 调用
_norm
方法对输入x
进行归一化。
- 调用
类型还原:
.type_as(x)
- 将归一化后的结果转换回输入
x
的原始数据类型。
- 将归一化后的结果转换回输入
应用缩放因子:
return output * self.weight
- 将归一化后的结果乘以可学习的缩放因子
self.weight
。
- 将归一化后的结果乘以可学习的缩放因子
举一个张量作为 RMSNorm 的例子
为了让大家更直观地理解 RMSNorm 的计算过程,我们通过一个具体的张量例子来详细说明。
示例张量
假设我们有一个简单的二维张量,形状为 (batch_size, dim)
,其中 batch_size = 2
和 dim = 4
。我们将逐步计算 RMSNorm 的结果。
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0]])
这个张量的形状是 (2, 4)
,表示有 2 个样本,每个样本有 4 个特征。
RMSNorm 的计算过程
1. 初始化 RMSNorm
假设我们初始化 RMSNorm 如下:
rms_norm = RMSNorm(dim=4, eps=1e-6)
这里,dim = 4
表示每个样本有 4 个特征,eps = 1e-6
是一个很小的数值,用于防止分母为零。
2. 计算均方根值(RMS)
对于输入张量 x
,我们首先计算每个样本的均方根值(RMS)。
计算平方:
x_pow_2 = x.pow(2)
结果为:
tensor([[ 1., 4., 9., 16.],
[25., 36., 49., 64.]])
计算平方的均值:
mean_pow_2 = x_pow_2.mean(-1, keepdim=True)
结果为:
tensor([[ 7.5000],
[41.0000]])
这里,mean(-1, keepdim=True)
表示沿着最后一个维度(特征维度)计算均值,并保持输出的维度与输入相同。
计算平方根的倒数:
inv_rms = torch.rsqrt(mean_pow_2 + rms_norm.eps)
假设 eps = 1e-6
,则:
tensor([[0.377964],
[0.156173]])
这里,torch.rsqrt
计算平方根的倒数,即 1 value \frac{1}{\sqrt{\text{value}}} value1。
3. 归一化
将输入张量 x
乘以平方根的倒数:
normalized_x = x * inv_rms
结果为:
tensor([[0.377964, 0.755929, 1.133893, 1.511858],
[0.780869, 0.937043, 1.093217, 1.249391]])
4. 应用缩放因子
假设 self.weight
初始化为全1:
rms_norm.weight = nn.Parameter(torch.ones(4))
则最终的输出为:
output = normalized_x * rms_norm.weight
结果为:
tensor([[0.377964, 0.755929, 1.133893, 1.511858],
[0.780869, 0.937043, 1.093217, 1.249391]])
通过这个例子,大家应该能够清楚地理解 RMSNorm 的计算过程。RMSNorm 在 Transformer 模型中有着重要的应用,它能够提高模型的训练效率和数值稳定性。希望你们能够掌握这个知识点,并在后续的学习中灵活运用。
四、旋转位置矩阵函数实现
# 获得旋转嵌入的实部和虚部
# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半
# 然后每个元素除以dim,再取theta的倒数,得到频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成一个从0到end的序列,长度为end
t = torch.arange(end, device=freqs.device)
# 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素
freqs = torch.outer(t, freqs).float()
# 计算频率的余弦值,得到实部
freqs_cos = torch.cos(freqs)
# 计算频率的正弦值,得到虚部
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
这段代码是用于计算旋转位置嵌入(RoPE,Rotary Position Embedding)的实部和虚部。旋转位置嵌入是一种用于处理序列数据(如自然语言处理中的文本序列)的编码方式,它通过将位置信息以旋转的形式嵌入到特征向量中,使得模型能够更好地捕捉序列中的相对位置信息。以下结合旋转位置矩阵公式来详细讲解这段代码:
旋转位置嵌入的基本概念
旋转位置嵌入的核心思想是将位置信息通过旋转矩阵的方式嵌入到特征向量中。对于一个维度为 d d d的特征向量,旋转位置嵌入将特征向量分成实部和虚部,分别对应余弦和正弦函数。具体来说,对于位置 i i i和特征维度 j j j,旋转嵌入的计算公式如下:
Re ( i , j ) = cos ( i θ j / d ) \text{Re}(i, j) = \cos\left(\frac{i}{\theta^{j/d}}\right) Re(i,j)=cos(θj/di)
Im ( i , j ) = sin ( i θ j / d ) \text{Im}(i, j) = \sin\left(\frac{i}{\theta^{j/d}}\right) Im(i,j)=sin(θj/di)
其中:
- i i i是位置索引,表示序列中的位置。
- j j j是特征维度索引。
- d d d是特征向量的维度。
- θ \theta θ是一个常数,通常取 10000 10000 10000,用于控制频率的缩放。
代码解析
这段代码的目的是根据上述公式预计算旋转位置嵌入的实部和虚部。我们逐步解析代码的每一部分:
1. 计算频率 freqs \text{freqs} freqs
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
torch.arange(0, dim, 2)
生成一个从 0 开始,步长为 2 的序列,长度为 d i m 2 \frac{dim}{2} 2dim。这是因为旋转位置嵌入中,每个维度的实部和虚部是交替计算的,所以只需要计算一半的维度。[: (dim // 2)]
确保取到的序列长度为 d i m 2 \frac{dim}{2} 2dim。torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
将每个元素除以 d d d,得到归一化后的维度索引。theta ** (...)
计算 θ \theta θ的幂次方,得到频率的缩放因子。1.0 / (...)
取倒数,得到最终的频率 freqs \text{freqs} freqs。
2. 生成时间序列 t t t
t = torch.arange(end, device=freqs.device)
torch.arange(end)
生成一个从 0 到 end − 1 \text{end} - 1 end−1的序列,表示序列中的位置索引 i i i。device=freqs.device
确保生成的张量与freqs
在同一设备(CPU 或 GPU)上。
3. 计算外积
freqs = torch.outer(t, freqs).float()
torch.outer(t, freqs)
计算 t t t和 freqs \text{freqs} freqs的外积,得到一个二维矩阵。矩阵的每个元素是 t [ i ] × freqs [ j ] t[i] \times \text{freqs}[j] t[i]×freqs[j],即位置 i i i和维度 j j j的频率乘积。
4. 计算实部和虚部
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
torch.cos(freqs)
计算频率的余弦值,得到旋转位置嵌入的实部。torch.sin(freqs)
计算频率的正弦值,得到旋转位置嵌入的虚部。
输出
最终,这段代码返回两个张量:
freqs_cos
:旋转位置嵌入的实部。freqs_sin
:旋转位置嵌入的虚部。
这两个张量可以用于后续的旋转操作,将位置信息嵌入到特征向量中。
在旋转位置嵌入(RoPE)中,实部和虚部的使用方式是将它们与嵌入向量结合,通过旋转操作来为模型提供位置信息。具体来说,实部和虚部会与嵌入向量的对应维度进行逐元素的乘法操作,从而实现位置信息的编码。以下是详细的步骤和解释:
位置编码中需要注意的点
1. 嵌入向量的拆分
假设我们有一个嵌入向量 E \mathbf{E} E,其维度为 dim \text{dim} dim。为了与旋转位置嵌入结合,我们需要将嵌入向量拆分为实部和虚部。通常,嵌入向量的偶数维度被视为实部,奇数维度被视为虚部。
例如,对于一个维度为 8 的嵌入向量 E \mathbf{E} E:
E = [ e 0 , e 1 , e 2 , e 3 , e 4 , e 5 , e 6 , e 7 ] \mathbf{E} = [e_0, e_1, e_2, e_3, e_4, e_5, e_6, e_7] E=[e0,e1,e2,e3,e4,e5,e6,e7]
可以拆分为:
- 实部: E real = [ e 0 , e 2 , e 4 , e 6 ] \mathbf{E}_{\text{real}} = [e_0, e_2, e_4, e_6] Ereal=[e0,e2,e4,e6]
- 虚部: E imag = [ e 1 , e 3 , e 5 , e 7 ] \mathbf{E}_{\text{imag}} = [e_1, e_3, e_5, e_7] Eimag=[e1,e3,e5,e7]
2. 旋转操作
旋转位置嵌入的核心是通过旋转矩阵将位置信息编码到嵌入向量中。旋转矩阵由实部和虚部构成,具体形式如下:
RoPE ( p o s ) = [ cos ( p o s ⋅ θ ) − sin ( p o s ⋅ θ ) sin ( p o s ⋅ θ ) cos ( p o s ⋅ θ ) ] \text{RoPE}(pos) = \begin{bmatrix} \cos(pos \cdot \theta) & -\sin(pos \cdot \theta) \\ \sin(pos \cdot \theta) & \cos(pos \cdot \theta) \end{bmatrix} RoPE(pos)=[cos(pos⋅θ)sin(pos⋅θ)−sin(pos⋅θ)cos(pos⋅θ)]
其中, p o s pos pos是位置索引, θ \theta θ是频率参数。
对于每个位置 p o s pos pos,我们有对应的实部 cos ( p o s ⋅ θ ) \cos(pos \cdot \theta) cos(pos⋅θ)和虚部 sin ( p o s ⋅ θ ) \sin(pos \cdot \theta) sin(pos⋅θ)。这些值在前面的代码中已经计算好了,分别存储在 freqs_cos
和 freqs_sin
中。
3. 应用旋转
将旋转矩阵应用于嵌入向量的实部和虚部。具体操作如下:
- 对于每个位置 p o s pos pos,我们有对应的实部 E real ( p o s ) \mathbf{E}_{\text{real}}(pos) Ereal(pos)和虚部 E imag ( p o s ) \mathbf{E}_{\text{imag}}(pos) Eimag(pos)。
- 使用旋转矩阵对这些值进行旋转:
[ E real ′ ( p o s ) E imag ′ ( p o s ) ] = [ cos ( p o s ⋅ θ ) − sin ( p o s ⋅ θ ) sin ( p o s ⋅ θ ) cos ( p o s ⋅ θ ) ] [ E real ( p o s ) E imag ( p o s ) ] \begin{bmatrix} \mathbf{E}_{\text{real}}'(pos) \\ \mathbf{E}_{\text{imag}}'(pos) \end{bmatrix}= \begin{bmatrix} \cos(pos \cdot \theta) & -\sin(pos \cdot \theta) \\ \sin(pos \cdot \theta) & \cos(pos \cdot \theta) \end{bmatrix} \begin{bmatrix} \mathbf{E}_{\text{real}}(pos) \\ \mathbf{E}_{\text{imag}}(pos) \end{bmatrix} [Ereal′(pos)Eimag′(pos)]=[cos(pos⋅θ)sin(pos⋅θ)−sin(pos⋅θ)cos(pos⋅θ)][Ereal(pos)Eimag(pos)]
具体来说,旋转后的实部和虚部为:
E real ′ ( p o s ) = E real ( p o s ) ⋅ cos ( p o s ⋅ θ ) − E imag ( p o s ) ⋅ sin ( p o s ⋅ θ ) \mathbf{E}_{\text{real}}'(pos) = \mathbf{E}_{\text{real}}(pos) \cdot \cos(pos \cdot \theta) - \mathbf{E}_{\text{imag}}(pos) \cdot \sin(pos \cdot \theta) Ereal′(pos)=Ereal(pos)⋅cos(pos⋅θ)−Eimag(pos)⋅sin(pos⋅θ)
E imag ′ ( p o s ) = E real ( p o s ) ⋅ sin ( p o s ⋅ θ ) + E imag ( p o s ) ⋅ cos ( p o s ⋅ θ ) \mathbf{E}_{\text{imag}}'(pos) = \mathbf{E}_{\text{real}}(pos) \cdot \sin(pos \cdot \theta) + \mathbf{E}_{\text{imag}}(pos) \cdot \cos(pos \cdot \theta) Eimag′(pos)=Ereal(pos)⋅sin(pos⋅θ)+Eimag(pos)⋅cos(pos⋅θ)
4. 合并旋转后的结果
将旋转后的实部和虚部重新组合成完整的嵌入向量。对于每个位置 p o s pos pos,我们有:
E ′ ( p o s ) = [ E real ′ ( p o s ) 0 , E imag ′ ( p o s ) 0 , E real ′ ( p o s ) 1 , E imag ′ ( p o s ) 1 , … ] \mathbf{E}'(pos) = [\mathbf{E}_{\text{real}}'(pos)_0, \mathbf{E}_{\text{imag}}'(pos)_0, \mathbf{E}_{\text{real}}'(pos)_1, \mathbf{E}_{\text{imag}}'(pos)_1, \dots] E′(pos)=[Ereal′(pos)0,Eimag′(pos)0,Ereal′(pos)1,Eimag′(pos)1,…]
具体的例子
我们来举一个具体的数值例子来说明这个过程。
假设我们有以下参数:
dim = 8
(嵌入向量的维度)end = 5
(序列的长度)theta = 10000.0
(频率参数)
步骤1:计算频率
首先,我们计算频率 freqs
:
dim = 8
theta = 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
torch.arange(0, dim, 2)
生成序列: 0 , 2 , 4 , 6 0, 2, 4, 6 0,2,4,6[: (dim // 2)]
取前一半: 0 , 2 , 4 , 6 0, 2, 4, 6 0,2,4,6- 每个元素除以
dim
: 0 8 , 2 8 , 4 8 , 6 8 \frac{0}{8}, \frac{2}{8}, \frac{4}{8}, \frac{6}{8} 80,82,84,86= 0 , 0.25 , 0.5 , 0.75 0, 0.25, 0.5, 0.75 0,0.25,0.5,0.75 - 取
theta
的倒数: 1 1000 0 0 , 1 1000 0 0.25 , 1 1000 0 0.5 , 1 1000 0 0.75 \frac{1}{10000^0}, \frac{1}{10000^{0.25}}, \frac{1}{10000^{0.5}}, \frac{1}{10000^{0.75}} 1000001,100000.251,100000.51,100000.751 - 计算得到: 1 , 0.003162 , 0.0001 , 0.00003162 1, 0.003162, 0.0001, 0.00003162 1,0.003162,0.0001,0.00003162
步骤2:生成时间序列
生成时间序列 t
:
end = 5
t = torch.arange(end, device=freqs.device)
t
为: 0 , 1 , 2 , 3 , 4 0, 1, 2, 3, 4 0,1,2,3,4
步骤3:计算外积
计算外积 freqs
:
freqs = torch.outer(t, freqs).float()
- 外积计算:
[ 0 × 1 0 × 0.003162 0 × 0.0001 0 × 0.00003162 1 × 1 1 × 0.003162 1 × 0.0001 1 × 0.00003162 2 × 1 2 × 0.003162 2 × 0.0001 2 × 0.00003162 3 × 1 3 × 0.003162 3 × 0.0001 3 × 0.00003162 4 × 1 4 × 0.003162 4 × 0.0001 4 × 0.00003162 ] = [ 0 0 0 0 1 0.003162 0.0001 0.00003162 2 0.006324 0.0002 0.00006324 3 0.009486 0.0003 0.00009486 4 0.012648 0.0004 0.00012648 ] \begin{bmatrix} 0 \times 1 & 0 \times 0.003162 & 0 \times 0.0001 & 0 \times 0.00003162 \\ 1 \times 1 & 1 \times 0.003162 & 1 \times 0.0001 & 1 \times 0.00003162 \\ 2 \times 1 & 2 \times 0.003162 & 2 \times 0.0001 & 2 \times 0.00003162 \\ 3 \times 1 & 3 \times 0.003162 & 3 \times 0.0001 & 3 \times 0.00003162 \\ 4 \times 1 & 4 \times 0.003162 & 4 \times 0.0001 & 4 \times 0.00003162 \\ \end{bmatrix}= \begin{bmatrix} 0 & 0 & 0 & 0 \\ 1 & 0.003162 & 0.0001 & 0.00003162 \\ 2 & 0.006324 & 0.0002 & 0.00006324 \\ 3 & 0.009486 & 0.0003 & 0.00009486 \\ 4 & 0.012648 & 0.0004 & 0.00012648 \\ \end{bmatrix} 0×11×12×13×14×10×0.0031621×0.0031622×0.0031623×0.0031624×0.0031620×0.00011×0.00012×0.00013×0.00014×0.00010×0.000031621×0.000031622×0.000031623×0.000031624×0.00003162 = 0123400.0031620.0063240.0094860.01264800.00010.00020.00030.000400.000031620.000063240.000094860.00012648
步骤4:计算实部和虚部
计算实部 freqs_cos
和虚部 freqs_sin
:
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
- 实部
freqs_cos
:
[ cos ( 0 ) cos ( 0 ) cos ( 0 ) cos ( 0 ) cos ( 1 ) cos ( 0.003162 ) cos ( 0.0001 ) cos ( 0.00003162 ) cos ( 2 ) cos ( 0.006324 ) cos ( 0.0002 ) cos ( 0.00006324 ) cos ( 3 ) cos ( 0.009486 ) cos ( 0.0003 ) cos ( 0.00009486 ) cos ( 4 ) cos ( 0.012648 ) cos ( 0.0004 ) cos ( 0.00012648 ) ] ≈ [ 1 1 1 1 0.5403 0.999995 0.999999 0.999999 − 0.4161 0.999983 0.999997 0.999998 − 0.989992 0.999969 0.999994 0.999997 − 0.653644 0.999953 0.999991 0.999996 ] \begin{bmatrix} \cos(0) & \cos(0) & \cos(0) & \cos(0) \\ \cos(1) & \cos(0.003162) & \cos(0.0001) & \cos(0.00003162) \\ \cos(2) & \cos(0.006324) & \cos(0.0002) & \cos(0.00006324) \\ \cos(3) & \cos(0.009486) & \cos(0.0003) & \cos(0.00009486) \\ \cos(4) & \cos(0.012648) & \cos(0.0004) & \cos(0.00012648) \\ \end{bmatrix} \approx \begin{bmatrix} 1 & 1 & 1 & 1 \\ 0.5403 & 0.999995 & 0.999999 & 0.999999 \\ -0.4161 & 0.999983 & 0.999997 & 0.999998 \\ -0.989992 & 0.999969 & 0.999994 & 0.999997 \\ -0.653644 & 0.999953 & 0.999991 & 0.999996 \\ \end{bmatrix} cos(0)cos(1)cos(2)cos(3)cos(4)cos(0)cos(0.003162)cos(0.006324)cos(0.009486)cos(0.012648)cos(0)cos(0.0001)cos(0.0002)cos(0.0003)cos(0.0004)cos(0)cos(0.00003162)cos(0.00006324)cos(0.00009486)cos(0.00012648) ≈ 10.5403−0.4161−0.989992−0.65364410.9999950.9999830.9999690.99995310.9999990.9999970.9999940.99999110.9999990.9999980.9999970.999996
- 虚部
freqs_sin
:
[ sin ( 0 ) sin ( 0 ) sin ( 0 ) sin ( 0 ) sin ( 1 ) sin ( 0.003162 ) sin ( 0.0001 ) sin ( 0.00003162 ) sin ( 2 ) sin ( 0.006324 ) sin ( 0.0002 ) sin ( 0.00006324 ) sin ( 3 ) sin ( 0.009486 ) sin ( 0.0003 ) sin ( 0.00009486 ) sin ( 4 ) sin ( 0.012648 ) sin ( 0.0004 ) sin ( 0.00012648 ) ] ≈ [ 0 0 0 0 0.841471 0.003162 0.0001 0.00003162 0.909297 0.006324 0.0002 0.00006324 0.14112 0.009486 0.0003 0.00009486 − 0.756802 0.012648 0.0004 0.00012648 ] \begin{bmatrix} \sin(0) & \sin(0) & \sin(0) & \sin(0) \\ \sin(1) & \sin(0.003162) & \sin(0.0001) & \sin(0.00003162) \\ \sin(2) & \sin(0.006324) & \sin(0.0002) & \sin(0.00006324) \\ \sin(3) & \sin(0.009486) & \sin(0.0003) & \sin(0.00009486) \\ \sin(4) & \sin(0.012648) & \sin(0.0004) & \sin(0.00012648) \\ \end{bmatrix} \approx \begin{bmatrix} 0 & 0 & 0 & 0 \\ 0.841471 & 0.003162 & 0.0001 & 0.00003162 \\ 0.909297 & 0.006324 & 0.0002 & 0.00006324 \\ 0.14112 & 0.009486 & 0.0003 & 0.00009486 \\ -0.756802 & 0.012648 & 0.0004 & 0.00012648 \\ \end{bmatrix} sin(0)sin(1)sin(2)sin(3)sin(4)sin(0)sin(0.003162)sin(0.006324)sin(0.009486)sin(0.012648)sin(0)sin(0.0001)sin(0.0002)sin(0.0003)sin(0.0004)sin(0)sin(0.00003162)sin(0.00006324)sin(0.00009486)sin(0.00012648) ≈ 00.8414710.9092970.14112−0.75680200.0031620.0063240.0094860.01264800.00010.00020.00030.000400.000031620.000063240.000094860.00012648
最终结果
函数返回两个矩阵:
freqs_cos
:表示旋转位置嵌入的实部。freqs_sin
:表示旋转位置嵌入的虚部。
这两个矩阵将用于对嵌入向量进行旋转,以提供位置信息。
五、旋转位置嵌入
实际上,旋转位置嵌入(RoPE)的使用方式与传统的固定位置编码(如 sin
和 cos
形式的编码)有所不同。它并不是直接加到嵌入向量(embedding)中,而是通过一种特殊的旋转操作来影响嵌入向量的表示,特别是在自注意力机制中。
传统位置编码 vs. 旋转位置嵌入
传统位置编码
在传统的Transformer模型中,位置编码是直接加到嵌入向量上的。具体来说:
- 嵌入向量 E \mathbf{E} E的形状为 ( batch_size , seq_len , dim ) (\text{batch\_size}, \text{seq\_len}, \text{dim}) (batch_size,seq_len,dim)。
- 位置编码 P E \mathbf{PE} PE的形状与嵌入向量相同。
- 通过逐元素相加的方式将位置编码融入嵌入向量:
E ′ = E + P E \mathbf{E}' = \mathbf{E} + \mathbf{PE} E′=E+PE
这种方法简单直接,但有一个缺点:当处理的序列长度超过训练时的最大长度时,位置编码可能会失效,因为位置编码是固定的。
旋转位置嵌入
旋转位置嵌入(RoPE)的核心思想是通过旋转操作将位置信息融入嵌入向量中,而不是直接相加。它的主要优点是能够处理任意长度的序列,因为旋转操作是动态的,不依赖于固定的编码。
在自注意力机制中的应用
旋转位置嵌入通常应用于自注意力机制中的查询(query)和键(key)向量,而不是直接应用于嵌入向量。具体步骤如下:
嵌入向量的处理:
- 嵌入向量 E \mathbf{E} E经过线性变换后,生成查询(query)、键(key)和值(value)向量:
Q = E W Q , K = E W K , V = E W V \mathbf{Q} = \mathbf{E} \mathbf{W}_Q, \quad \mathbf{K} = \mathbf{E} \mathbf{W}_K, \quad \mathbf{V} = \mathbf{E} \mathbf{W}_V Q=EWQ,K=EWK,V=EWV
其中, W Q , W K , W V \mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V WQ,WK,WV是可学习的权重矩阵。
- 嵌入向量 E \mathbf{E} E经过线性变换后,生成查询(query)、键(key)和值(value)向量:
应用旋转位置嵌入:
- 在自注意力机制中,查询(query)和键(key)向量会通过旋转位置嵌入进行调整。具体来说,查询和键向量的每个维度会被拆分为实部和虚部,然后通过旋转矩阵进行旋转:
Q ′ = RoPE ( Q , pos ) \mathbf{Q}' = \text{RoPE}(\mathbf{Q}, \text{pos}) Q′=RoPE(Q,pos)
K ′ = RoPE ( K , pos ) \mathbf{K}' = \text{RoPE}(\mathbf{K}, \text{pos}) K′=RoPE(K,pos)
其中, RoPE \text{RoPE} RoPE是旋转位置嵌入操作, pos \text{pos} pos是位置索引。
- 在自注意力机制中,查询(query)和键(key)向量会通过旋转位置嵌入进行调整。具体来说,查询和键向量的每个维度会被拆分为实部和虚部,然后通过旋转矩阵进行旋转:
自注意力计算:
- 旋转后的查询和键向量用于计算注意力分数:
Attention ( Q ′ , K ′ , V ) = softmax ( Q ′ K ′ T dim ) V \text{Attention}(\mathbf{Q}', \mathbf{K}', \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}' \mathbf{K}'^T}{\sqrt{\text{dim}}}\right) \mathbf{V} Attention(Q′,K′,V)=softmax(dimQ′K′T)V
- 旋转后的查询和键向量用于计算注意力分数:
为什么在自注意力中而不是嵌入层中
旋转位置嵌入在自注意力机制中应用的原因主要有以下几点:
- 动态性:旋转位置嵌入通过旋转操作动态地融入位置信息,能够处理任意长度的序列。
- 相对位置信息:旋转位置嵌入编码的是相对位置信息,而不是绝对位置。这使得模型能够更好地捕捉长文本中的上下文关系。
- 灵活性:旋转操作可以灵活地应用于查询和键向量,而不需要改变嵌入层的结构。
ROPE函数的实现
# 此函数的作用是将freqs_cis调整为与x的形状相同,以便能够与x进行广播操作
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# 获取x的维度数
ndim = x.ndim
# 断言,确保1在x的维度范围内
assert 0 <= 1 < ndim
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 将freqs_cis调整为新的形状,并返回
return freqs_cis.view(shape)
#这个函数的作用是将旋转位置嵌入(RoPE)应用于查询(query)和键(key)张量。
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# 重新塑形频率张量以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# 应用旋转,分别计算旋转后的实部和虚部
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
这个函数的作用是将键值对(key-value pairs)进行扩展,以增加键值对的数量
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
bs, slen, n_kv_heads, head_dim = x.shape
# 如果重复次数为1,则不需要重复,直接返回原始张量
if n_rep == 1:
return x
# 对张量进行扩展和重塑操作以重复键值对
return (
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
)
这段代码包含了三个函数,分别是 reshape_for_broadcast
、apply_rotary_emb
和 repeat_kv
。这些函数在Transformer模型中用于处理旋转位置嵌入(RoPE)和键值对的扩展操作。下面我将逐一解析这些函数的作用和实现细节。
1. reshape_for_broadcast
函数
这个函数的作用是将频率张量 freqs_cis
调整为与张量 x
的形状相同,以便能够与 x
进行广播操作。
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# 获取x的维度数
ndim = x.ndim
# 断言,确保1在x的维度范围内
assert 0 <= 1 < ndim
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 将freqs_cis调整为新的形状,并返回
return freqs_cis.view(shape)
广播操作示例
假设我们有以下张量:
freqs_cis
:形状为(3, 5)
,值为:
[ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ] \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \end{bmatrix} 161127123813491451015
x
:形状为(2, 3, 4, 5)
,值为:
[ [ [ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 ] [ 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 ] [ 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 ] ] [ [ 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 ] [ 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 ] [ 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 ] ] ] \begin{bmatrix} \begin{bmatrix} \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 \end{bmatrix} \\ \begin{bmatrix} 21 & 22 & 23 & 24 & 25 \\ 26 & 27 & 28 & 29 & 30 \\ 31 & 32 & 33 & 34 & 35 \\ 36 & 37 & 38 & 39 & 40 \end{bmatrix} \\ \begin{bmatrix} 41 & 42 & 43 & 44 & 45 \\ 46 & 47 & 48 & 49 & 50 \\ 51 & 52 & 53 & 54 & 55 \\ 56 & 57 & 58 & 59 & 60 \end{bmatrix} \end{bmatrix} \\ \begin{bmatrix} \begin{bmatrix} 61 & 62 & 63 & 64 & 65 \\ 66 & 67 & 68 & 69 & 70 \\ 71 & 72 & 73 & 74 & 75 \\ 76 & 77 & 78 & 79 & 80 \end{bmatrix} \\ \begin{bmatrix} 81 & 82 & 83 & 84 & 85 \\ 86 & 87 & 88 & 89 & 90 \\ 91 & 92 & 93 & 94 & 95 \\ 96 & 97 & 98 & 99 & 100 \end{bmatrix} \\ \begin{bmatrix} 101 & 102 & 103 & 104 & 105 \\ 106 & 107 & 108 & 109 & 110 \\ 111 & 112 & 113 & 114 & 115 \\ 116 & 117 & 118 & 119 & 120 \end{bmatrix} \end{bmatrix} \end{bmatrix} 1611162712173813184914195101520 2126313622273237232833382429343925303540 4146515642475257434853584449545945505560 6166717662677277636873786469747965707580 81869196828792978388939884899499859095100 101106111116102107112117103108113118104109114119105110115120
调整后的 freqs_cis
形状为 (1, 3, 1, 5)
,值为:
[ [ [ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ] ] ] \begin{bmatrix} \begin{bmatrix} \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \end{bmatrix} \end{bmatrix} \end{bmatrix}
161127123813491451015
1.1解析
- 输入:
freqs_cis
:频率张量,形状为(seq_len, dim // 2)
。x
:目标张量,形状为(batch_size, seq_len, dim)
。
- 输出:
- 调整后的频率张量,形状为
(1, seq_len, 1, dim // 2)
。
- 调整后的频率张量,形状为
1.2作用
- 通过调整
freqs_cis
的形状,使其能够与x
进行逐元素的广播操作。 - 例如,假设
x
的形状为(batch_size, seq_len, dim)
,则调整后的freqs_cis
形状为(1, seq_len, 1, dim // 2)
,这样可以与x
的形状(batch_size, seq_len, dim // 2, 2)
进行广播。
2. apply_rotary_emb
函数
这个函数的作用是将旋转位置嵌入(RoPE)应用于查询(query)和键(key)张量。
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# 重新塑形频率张量以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# 应用旋转,分别计算旋转后的实部和虚部
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
好的,我们通过一个具体的例子,逐步说明 apply_rotary_emb
函数的每个步骤。为了简化计算,我们假设维度较小。
示例参数
batch_size = 2
seq_len = 3
dim = 4
(嵌入维度)dim // 2 = 2
(因为旋转位置嵌入将维度拆分为实部和虚部)
示例输入
import torch
# 查询张量 xq 和键张量 xk
xq = torch.tensor([
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
[[13.0, 14.0, 15.0, 16.0], [17.0, 18.0, 19.0, 20.0], [21.0, 22.0, 23.0, 24.0]]
])
xk = torch.tensor([
[[25.0, 26.0, 27.0, 28.0], [29.0, 30.0, 31.0, 32.0], [33.0, 34.0, 35.0, 36.0]],
[[37.0, 38.0, 39.0, 40.0], [41.0, 42.0, 43.0, 44.0], [45.0, 46.0, 47.0, 48.0]]
])
# 频率张量 freqs_cos 和 freqs_sin
freqs_cos = torch.tensor([
[1.0, 0.5],
[0.8, 0.9],
[0.7, 0.6]
])
freqs_sin = torch.tensor([
[0.0, 0.5],
[0.6, 0.4],
[0.3, 0.8]
])
步骤1:将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
示例计算
xq
的形状为(2, 3, 4)
,重塑后为(2, 3, 2, 2)
:xq_reshaped = xq.float().reshape(xq.shape[:-1] + (-1, 2))
结果:
tensor([[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]], [[[13., 14.], [15., 16.]], [[17., 18.], [19., 20.]], [[21., 22.], [23., 24.]]]])
分离实部和虚部:
xq_r, xq_i = xq_reshaped.unbind(-1)
结果:
xq_r = tensor([[[1., 3.], [5., 7.], [9., 11.]], [[13., 15.], [17., 19.], [21., 23.]]]) xq_i = tensor([[[2., 4.], [6., 8.], [10., 12.]], [[14., 16.], [18., 20.], [22., 24.]]])
同理,xk
也会被拆分为 xk_r
和 xk_i
。
步骤2:调整频率张量的形状以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
2.示例计算
freqs_cos
和freqs_sin
的形状为(3, 2)
,调整后为(1, 3, 1, 2)
:freqs_cos = freqs_cos.view(1, 3, 1, 2) freqs_sin = freqs_sin.view(1, 3, 1, 2)
步骤3:应用旋转操作
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
3.示例计算
计算
xq_out_r
和xq_out_i
:xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
具体计算如下:
# xq_r 和 xq_i 的形状为 (2, 3, 2)
# freqs_cos 和 freqs_sin 的形状为 (1, 3, 1, 2),广播后为 (2, 3, 2)
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
结果:
xq_out_r = tensor([[[1.0, 1.5], [3.2, 3.8], [5.4, 6.2]],
[[12.6, 13.5], [15.2, 16.1], [17.8, 18.7]]])
xq_out_i = tensor([[[0.0, 1.0], [1.8, 2.4], [2.7, 3.6]],
[[7.0, 8.0], [8.2, 9.4], [9.6, 10.8]]])
同理,计算 xk_out_r
和 xk_out_i
。
步骤4:将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
4.示例计算
将实部和虚部重新组合为一个张量:
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
结果:
xq_out = tensor([[[1.0, 0.0, 1.5, 1.0], [3.2, 1.8, 3.8, 2.4], [5.4, 2.7, 6.2, 3.6]], [[12.6, 7.0, 13.5, 8.0], [15.2, 8.2, 16.1, 9.4], [17.8, 9.6, 18.7, 10.8]]])
同理,xk_out
也会被重新组合。
4.最终结果
xq_out
和xk_out
的形状为(2, 3, 4)
,与原始的xq
和xk
形状相同,但已经融入了旋转位置嵌入。
总结
通过 apply_rotary_emb
函数,我们将查询和键张量的每个维度拆分为实部和虚部,然后通过旋转操作将位置信息融入到嵌入向量中。这种方法能够动态地处理任意长度的序列,避免了传统位置编码的固定长度限制。
2.1解析
- 输入:
xq
:查询张量,形状为(batch_size, seq_len, dim)
。xk
:键张量,形状为(batch_size, seq_len, dim)
。freqs_cos
:旋转位置嵌入的实部,形状为(seq_len, dim // 2)
。freqs_sin
:旋转位置嵌入的虚部,形状为(seq_len, dim // 2)
。
- 输出:
xq_out
:应用旋转位置嵌入后的查询张量。xk_out
:应用旋转位置嵌入后的键张量。
2.2作用
- 拆分实部和虚部:
- 将查询和键张量的每个维度拆分为实部和虚部。
- 例如,假设
dim = 8
,则xq
和xk
的形状为(batch_size, seq_len, 8)
。经过拆分后,xq_r
和xq_i
的形状为(batch_size, seq_len, 4)
。
- 调整频率张量的形状:
- 使用
reshape_for_broadcast
函数将freqs_cos
和freqs_sin
调整为与xq_r
和xq_i
的形状一致。
- 使用
- 应用旋转操作:
- 使用旋转矩阵公式将查询和键的实部和虚部进行旋转。
- 合并结果:
- 将旋转后的实部和虚部重新组合为完整的嵌入向量,恢复为原始的形状
(batch_size, seq_len, dim)
。
- 将旋转后的实部和虚部重新组合为完整的嵌入向量,恢复为原始的形状
3. repeat_kv
函数
这个函数的作用是将键值对(key-value pairs)进行扩展,以增加键值对的数量。
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
bs, slen, n_kv_heads, head_dim = x.shape
# 如果重复次数为1,则不需要重复,直接返回原始张量
if n_rep == 1:
return x
# 对张量进行扩展和重塑操作以重复键值对
return (
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
)
假设n_rep = 3
输入为
import torch
输入张量 x
x = torch.tensor([
[
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25], [26, 27, 28, 29, 30]]
],
[
[[31, 32, 33, 34, 35], [36, 37, 38, 39, 40]],
[[41, 42, 43, 44, 45], [46, 47, 48, 49, 50]],
[[51, 52, 53, 54, 55], [56, 57, 58, 59, 60]]
]
])
输出为为
tensor([
[
[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [6, 7, 8, 9, 10], [6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15], [11, 12, 13, 14, 15], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20], [16, 17, 18, 19, 20], [16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25], [21, 22, 23, 24, 25], [21, 22, 23, 24, 25], [26, 27, 28, 29, 30], [26, 27, 28, 29, 30], [26, 27, 28, 29, 30]]
],
[
[[31, 32, 33, 34, 35], [31, 32, 33, 34, 35], [31, 32, 33, 34, 35], [36, 37, 38, 39, 40], [36, 37, 38, 39, 40], [36, 37, 38, 39, 40]],
[[41, 42, 43, 44, 45], [41, 42, 43, 44, 45], [41, 42, 43, 44, 45], [46, 47, 48, 49, 50], [46, 47, 48, 49, 50], [46, 47, 48, 49, 50]],
[[51, 52, 53, 54, 55], [51, 52, 53, 54, 55], [51, 52, 53, 54, 55], [56, 57, 58, 59, 60], [56, 57, 58, 59, 60], [56, 57, 58, 59, 60]]
]
])
最终,x 的形状从 (2, 3, 2, 5) 变为 (2, 3, 6, 5),每个头被重复了 3 次。
3.1解析
- 输入:
x
:键值对张量,形状为(batch_size, seq_len, n_kv_heads, head_dim)
。n_rep
:重复次数。
- 输出:
- 扩展后的键值对张量,形状为
(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
。
- 扩展后的键值对张量,形状为
3.2作用
- 扩展键值对:
- 如果
n_rep > 1
,则将键值对张量在头的维度上进行扩展。 - 例如,假设
x
的形状为(batch_size, seq_len, n_kv_heads, head_dim)
,重复次数为n_rep
,则扩展后的张量形状为(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
。
- 如果
- 重塑张量:
- 将扩展后的张量重新塑形,合并头的数量和重复次数的维度。
3.3总结
reshape_for_broadcast
:调整频率张量的形状,以便与目标张量进行广播操作。apply_rotary_emb
:将旋转位置嵌入应用于查询和键张量,通过旋转操作将位置信息融入嵌入向量中。repeat_kv
:扩展键值对的数量,以增加模型的计算能力。
这些函数在Transformer模型中起到了关键作用,特别是在处理长文本序列和自注意力机制中。
六、attention 模块
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 确保总头数可以被键值头数整除。
assert args.n_heads % self.n_kv_heads == 0
# 模型并行处理大小,默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。
self.n_local_heads = args.n_heads // model_parallel_size
# 本地键值头数,等于键值头数除以模型并行处理大小。
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
# 重复次数,用于扩展键和值的尺寸。
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
# 定义权重矩阵。
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
# 输出权重矩阵。
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# 定义dropout。
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
# 保存dropout概率。
self.dropout = args.dropout
# 检查是否使用Flash Attention(需要PyTorch >= 2.0)。
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
# 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# 创建一个上三角矩阵,用于遮蔽未来信息。负无穷(inf)大填充张量
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
#torch.triu 函数将遮蔽矩阵转换为上三角矩阵,在自注意力机制中,遮蔽矩阵通常用于计算注意力分数时。具体来说,遮蔽矩阵会加到注意力分数上,使得未来的信息对应的分数变为负无穷大,从而在 Softmax 函数后权重趋近于零
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = x.shape
# 计算查询(Q)、键(K)、值(V)。
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 调整形状以适应头的维度。
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 应用旋转位置嵌入(RoPE)。
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# 对键和值进行扩展以适应重复次数。
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# 将头作为批次维度处理。
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# 根据是否支持Flash Attention,选择实现方式。
if self.flash:
# 使用Flash Attention。
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# 使用手动实现的注意力机制。
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
我们假设有一个具体的输入矩阵 x x x,以及相应的旋转位置嵌入 freqs_cos \text{freqs\_cos} freqs_cos和 freqs_sin \text{freqs\_sin} freqs_sin。为了简化说明,我们假设以下参数:
batch_size = 1
seq_len = 3
dim = 4
n_heads = 2
n_kv_heads = 2
(假设n_kv_heads
未指定,因此等于n_heads
)head_dim = 2
因为 head_dim = dim n_heads = 4 2 = 2 \text{head\_dim} = \frac{\text{dim}}{\text{n\_heads}} = \frac{4}{2} = 2 head_dim=n_headsdim=24=2model_parallel_size = 1
n_local_heads = n_heads / model_parallel_size = 2 / 1 = 2
n_local_kv_heads = n_kv_heads / model_parallel_size = 2 / 1 = 2
n_rep = n_local_heads / n_local_kv_heads = 2 / 2 = 1
我们假设输入矩阵 x x x为:
x = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] x = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} x=
159261037114812
旋转位置嵌入 freqs_cos \text{freqs\_cos} freqs_cos和 freqs_sin \text{freqs\_sin} freqs_sin为:
freqs_cos = [ cos ( 0 ) cos ( 1 ) cos ( 2 ) cos ( 3 ) cos ( 4 ) cos ( 5 ) ] \text{freqs\_cos} = \begin{bmatrix} \cos(0) & \cos(1) \\ \cos(2) & \cos(3) \\ \cos(4) & \cos(5) \end{bmatrix} freqs_cos=
cos(0)cos(2)cos(4)cos(1)cos(3)cos(5)
freqs_sin = [ sin ( 0 ) sin ( 1 ) sin ( 2 ) sin ( 3 ) sin ( 4 ) sin ( 5 ) ] \text{freqs\_sin} = \begin{bmatrix} \sin(0) & \sin(1) \\ \sin(2) & \sin(3) \\ \sin(4) & \sin(5) \end{bmatrix} freqs_sin=
sin(0)sin(2)sin(4)sin(1)sin(3)sin(5)
前向传播过程
好的,让我们继续之前的例子,详细说明在自注意力机制的 forward
方法中,如何从查询(Q)、键(K)、值(V)计算出最终的输出矩阵 output \text{output} output。为了简化说明,我们假设所有随机初始化的权重矩阵和旋转位置嵌入都是已知的,并且使用手动实现的注意力机制。
输入和参数回顾
- 输入矩阵 x x x:
x = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] x = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} x= 159261037114812
形状为
[batch_size, seq_len, dim]
,其中batch_size = 1
,seq_len = 3
,dim = 4
。旋转位置嵌入:
freqs_cos = [ cos ( 0 ) cos ( 1 ) cos ( 2 ) cos ( 3 ) cos ( 4 ) cos ( 5 ) ] \text{freqs\_cos} = \begin{bmatrix} \cos(0) & \cos(1) \\ \cos(2) & \cos(3) \\ \cos(4) & \cos(5) \end{bmatrix} freqs_cos=
cos(0)cos(2)cos(4)cos(1)cos(3)cos(5)
freqs_sin = [ sin ( 0 ) sin ( 1 ) sin ( 2 ) sin ( 3 ) sin ( 4 ) sin ( 5 ) ] \text{freqs\_sin} = \begin{bmatrix} \sin(0) & \sin(1) \\ \sin(2) & \sin(3) \\ \sin(4) & \sin(5) \end{bmatrix} freqs_sin=
sin(0)sin(2)sin(4)sin(1)sin(3)sin(5)
- 参数:
n_heads = 2
n_kv_heads = 2
head_dim = 2
model_parallel_size = 1
n_local_heads = 2
n_local_kv_heads = 2
n_rep = 1
前向传播过程
1. 计算查询(Q)、键(K)、值(V)
假设权重矩阵 w q wq wq、 w k wk wk和 w v wv wv是随机初始化的。为了简化说明,我们假设它们的值如下:
w q = [ 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 ] wq = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} wq=
1000010000100001
w k = [ 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 1 ] wk = \begin{bmatrix} 0 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} wk=
0100100000100001
w v = [ 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 ] wv = \begin{bmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix} wv=
0010000110000100
计算 x q xq xq、 x k xk xk和 x v xv xv:
x q = w q ( x ) = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] xq = wq(x) = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} xq=wq(x)=
159261037114812
x k = w k ( x ) = [ 2 1 3 4 6 5 7 8 10 9 11 12 ] xk = wk(x) = \begin{bmatrix} 2 & 1 & 3 & 4 \\ 6 & 5 & 7 & 8 \\ 10 & 9 & 11 & 12 \end{bmatrix} xk=wk(x)=
261015937114812
x v = w v ( x ) = [ 3 4 1 2 7 8 5 6 11 12 9 10 ] xv = wv(x) = \begin{bmatrix} 3 & 4 & 1 & 2 \\ 7 & 8 & 5 & 6 \\ 11 & 12 & 9 & 10 \end{bmatrix} xv=wv(x)=
371148121592610
2. 调整形状
将 x q xq xq、 x k xk xk和 x v xv xv调整为 [batch_size, seq_len, n_local_heads, head_dim]
:
x q = [ [ 1 2 3 4 ] , [ 5 6 7 8 ] , [ 9 10 11 12 ] ] xq = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \\ \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}, \\ \begin{bmatrix} 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} xq=
[1324],[5768],[9111012]
x k = [ [ 2 1 3 4 ] , [ 6 5 7 8 ] , [ 10 9 11 12 ] ] xk = \begin{bmatrix} \begin{bmatrix} 2 & 1 \\ 3 & 4 \end{bmatrix}, \\ \begin{bmatrix} 6 & 5 \\ 7 & 8 \end{bmatrix}, \\ \begin{bmatrix} 10 & 9 \\ 11 & 12 \end{bmatrix} \end{bmatrix} xk=
[2314],[6758],[1011912]
x v = [ [ 3 4 1 2 ] , [ 7 8 5 6 ] , [ 11 12 9 10 ] ] xv = \begin{bmatrix} \begin{bmatrix} 3 & 4 \\ 1 & 2 \end{bmatrix}, \\ \begin{bmatrix} 7 & 8 \\ 5 & 6 \end{bmatrix}, \\ \begin{bmatrix} 11 & 12 \\ 9 & 10 \end{bmatrix} \end{bmatrix} xv=
[3142],[7586],[1191210]
3. 应用旋转位置嵌入(RoPE)
假设旋转位置嵌入的函数 apply_rotary_emb
已经实现,我们直接给出结果:
x q = apply_rotary_emb ( x q , freqs_cos , freqs_sin ) xq = \text{apply\_rotary\_emb}(xq, \text{freqs\_cos}, \text{freqs\_sin}) xq=apply_rotary_emb(xq,freqs_cos,freqs_sin)
x k = apply_rotary_emb ( x k , freqs_cos , freqs_sin ) xk = \text{apply\_rotary\_emb}(xk, \text{freqs\_cos}, \text{freqs\_sin}) xk=apply_rotary_emb(xk,freqs_cos,freqs_sin)
4. 调整维度
将头作为批次维度处理:
x q = x q . t r a n s p o s e ( 1 , 2 ) = [ [ 1 2 5 6 9 10 ] , [ 3 4 7 8 11 12 ] ] xq = xq.transpose(1, 2) = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 5 & 6 \\ 9 & 10 \end{bmatrix}, \\ \begin{bmatrix} 3 & 4 \\ 7 & 8 \\ 11 & 12 \end{bmatrix} \end{bmatrix} xq=xq.transpose(1,2)=
1592610
,
37114812
x k = x k . t r a n s p o s e ( 1 , 2 ) = [ [ 2 1 6 5 10 9 ] , [ 3 4 7 8 11 12 ] ] xk = xk.transpose(1, 2) = \begin{bmatrix} \begin{bmatrix} 2 & 1 \\ 6 & 5 \\ 10 & 9 \end{bmatrix}, \\ \begin{bmatrix} 3 & 4 \\ 7 & 8 \\ 11 & 12 \end{bmatrix} \end{bmatrix} xk=xk.transpose(1,2)=
2610159
,
37114812
x v = x v . t r a n s p o s e ( 1 , 2 ) = [ [ 3 4 7 8 11 12 ] , [ 1 2 5 6 9 10 ] ] xv = xv.transpose(1, 2) = \begin{bmatrix} \begin{bmatrix} 3 & 4 \\ 7 & 8 \\ 11 & 12 \end{bmatrix}, \\ \begin{bmatrix} 1 & 2 \\ 5 & 6 \\ 9 & 10 \end{bmatrix} \end{bmatrix} xv=xv.transpose(1,2)=
37114812
,
1592610
5. 注意力计算
计算注意力分数:
scores = x q ⋅ x k T head_dim = 1 2 [ [ 1 ⋅ 2 + 2 ⋅ 1 1 ⋅ 6 + 2 ⋅ 5 1 ⋅ 10 + 2 ⋅ 9 5 ⋅ 2 + 6 ⋅ 1 5 ⋅ 6 + 6 ⋅ 5 5 ⋅ 10 + 6 ⋅ 9 9 ⋅ 2 + 10 ⋅ 1 9 ⋅ 6 + 10 ⋅ 5 9 ⋅ 10 + 10 ⋅ 9 ] , [ 3 ⋅ 3 + 4 ⋅ 4 3 ⋅ 7 + 4 ⋅ 8 3 ⋅ 11 + 4 ⋅ 12 7 ⋅ 3 + 8 ⋅ 4 7 ⋅ 7 + 8 ⋅ 8 7 ⋅ 11 + 8 ⋅ 12 11 ⋅ 3 + 12 ⋅ 4 11 ⋅ 7 + 12 ⋅ 8 11 ⋅ 11 + 12 ⋅ 12 ] ] \text{scores} = \frac{xq \cdot xk^T}{\sqrt{\text{head\_dim}}} = \frac{1}{\sqrt{2}} \begin{bmatrix} \begin{bmatrix} 1 \cdot 2 + 2 \cdot 1 & 1 \cdot 6 + 2 \cdot 5 & 1 \cdot 10 + 2 \cdot 9 \\ 5 \cdot 2 + 6 \cdot 1 & 5 \cdot 6 + 6 \cdot 5 & 5 \cdot 10 + 6 \cdot 9 \\ 9 \cdot 2 + 10 \cdot 1 & 9 \cdot 6 + 10 \cdot 5 & 9 \cdot 10 + 10 \cdot 9 \end{bmatrix}, \\ \begin{bmatrix} 3 \cdot 3 + 4 \cdot 4 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 11 + 4 \cdot 12 \\ 7 \cdot 3 + 8 \cdot 4 & 7 \cdot 7 + 8 \cdot 8 & 7 \cdot 11 + 8 \cdot 12 \\ 11 \cdot 3 + 12 \cdot 4 & 11 \cdot 7 + 12 \cdot 8 & 11 \cdot 11 + 12 \cdot 12 \end{bmatrix} \end{bmatrix} scores=head_dimxq⋅xkT=21
1⋅2+2⋅15⋅2+6⋅19⋅2+10⋅11⋅6+2⋅55⋅6+6⋅59⋅6+10⋅51⋅10+2⋅95⋅10+6⋅99⋅10+10⋅9
,
3⋅3+4⋅47⋅3+8⋅411⋅3+12⋅43⋅7+4⋅87⋅7+8⋅811⋅7+12⋅83⋅11+4⋅127⋅11+8⋅1211⋅11+12⋅12
6.添加因果遮蔽矩阵
因果遮蔽矩阵为:
mask = [ 0 − ∞ − ∞ 0 0 − ∞ 0 0 0 ] \text{mask} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix} mask=
000−∞00−∞−∞0
应用因果遮蔽:
scores 1 = 1 2 [ 4 − ∞ − ∞ 16 60 − ∞ 28 104 180 ] \text{scores}_1 = \frac{1}{\sqrt{2}} \begin{bmatrix} 4 & -\infty & -\infty \\ 16 & 60 & -\infty \\ 28 & 104 & 180 \end{bmatrix} scores1=21
41628−∞60104−∞−∞180
scores 2 = 1 2 [ 25 − ∞ − ∞ 53 113 − ∞ 81 173 265 ] \text{scores}_2 = \frac{1}{\sqrt{2}} \begin{bmatrix} 25 & -\infty & -\infty \\ 53 & 113 & -\infty \\ 81 & 173 & 265 \end{bmatrix} scores2=21
255381−∞113173−∞−∞265
应用 Softmax
假设 Softmax 的结果是随机生成的,但符合 Softmax 的性质(即每一行的值加起来为 1,并且 − ∞ -\infty −∞对应的值为 0)。我们将这些值代入后续的计算中。
添加因果遮蔽矩阵并应用 Softmax
假设 Softmax 的结果如下(随机生成,但符合因果遮蔽的要求):
对于第一个头:
scores 1 = softmax ( 1 2 [ 4 − ∞ − ∞ 16 60 − ∞ 28 104 180 ] ) = [ 1.0 0 0 0.2 0.8 0 0.1 0.3 0.6 ] \text{scores}_1 = \text{softmax}\left(\frac{1}{\sqrt{2}} \begin{bmatrix} 4 & -\infty & -\infty \\ 16 & 60 & -\infty \\ 28 & 104 & 180 \end{bmatrix}\right) = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.2 & 0.8 & 0 \\ 0.1 & 0.3 & 0.6 \end{bmatrix} scores1=softmax
21
41628−∞60104−∞−∞180
=
1.00.20.100.80.3000.6
对于第二个头:
scores 2 = softmax ( 1 2 [ 25 − ∞ − ∞ 53 113 − ∞ 81 173 265 ] ) = [ 1.0 0 0 0.3 0.7 0 0.1 0.2 0.7 ] \text{scores}_2 = \text{softmax}\left(\frac{1}{\sqrt{2}} \begin{bmatrix} 25 & -\infty & -\infty \\ 53 & 113 & -\infty \\ 81 & 173 & 265 \end{bmatrix}\right) = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.3 & 0.7 & 0 \\ 0.1 & 0.2 & 0.7 \end{bmatrix} scores2=softmax
21
255381−∞113173−∞−∞265
=
1.00.30.100.70.2000.7
应用 Dropout
假设 Dropout 概率为 0.1,随机生成一个与分数矩阵形状相同的掩码矩阵。为了简化,我们假设掩码矩阵如下(随机生成,但符合 90% 保留的概率):
对于第一个头:
dropout_mask 1 = [ 1 0 0 1 1 0 1 1 1 ] \text{dropout\_mask}_1 = \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix} dropout_mask1=
111011001
对于第二个头:
dropout_mask 2 = [ 1 0 0 1 1 0 1 1 1 ] \text{dropout\_mask}_2 = \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix} dropout_mask2=
111011001
应用 Dropout:
scores 1 = scores 1 × dropout_mask 1 = [ 1.0 0 0 0.2 0.8 0 0.1 0.3 0.6 ] \text{scores}_1 = \text{scores}_1 \times \text{dropout\_mask}_1 = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.2 & 0.8 & 0 \\ 0.1 & 0.3 & 0.6 \end{bmatrix} scores1=scores1×dropout_mask1=
1.00.20.100.80.3000.6
scores 2 = scores 2 × dropout_mask 2 = [ 1.0 0 0 0.3 0.7 0 0.1 0.2 0.7 ] \text{scores}_2 = \text{scores}_2 \times \text{dropout\_mask}_2 = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.3 & 0.7 & 0 \\ 0.1 & 0.2 & 0.7 \end{bmatrix} scores2=scores2×dropout_mask2= 1.00.30.100.70.2000.7
计算加权和
假设值矩阵 x v xv xv为:
x v = [ [ 3 4 1 2 ] , [ 7 8 5 6 ] , [ 11 12 9 10 ] ] xv = \begin{bmatrix} \begin{bmatrix} 3 & 4 \\ 1 & 2 \end{bmatrix}, \\ \begin{bmatrix} 7 & 8 \\ 5 & 6 \end{bmatrix}, \\ \begin{bmatrix} 11 & 12 \\ 9 & 10 \end{bmatrix} \end{bmatrix} xv=
[3142],[7586],[1191210]
计算第一个头的加权和:
output 1 = scores 1 ⋅ [ 3 4 1 2 ] = [ 1.0 ⋅ 3 + 0 ⋅ 1 + 0 ⋅ 11 1.0 ⋅ 4 + 0 ⋅ 2 + 0 ⋅ 12 0.2 ⋅ 3 + 0.8 ⋅ 1 + 0 ⋅ 11 0.2 ⋅ 4 + 0.8 ⋅ 2 + 0 ⋅ 12 0.1 ⋅ 3 + 0.3 ⋅ 1 + 0.6 ⋅ 11 0.1 ⋅ 4 + 0.3 ⋅ 2 + 0.6 ⋅ 12 ] \text{output}_1 = \text{scores}_1 \cdot \begin{bmatrix} 3 & 4 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1.0 \cdot 3 + 0 \cdot 1 + 0 \cdot 11 & 1.0 \cdot 4 + 0 \cdot 2 + 0 \cdot 12 \\ 0.2 \cdot 3 + 0.8 \cdot 1 + 0 \cdot 11 & 0.2 \cdot 4 + 0.8 \cdot 2 + 0 \cdot 12 \\ 0.1 \cdot 3 + 0.3 \cdot 1 + 0.6 \cdot 11 & 0.1 \cdot 4 + 0.3 \cdot 2 + 0.6 \cdot 12 \end{bmatrix} output1=scores1⋅[3142]=
1.0⋅3+0⋅1+0⋅110.2⋅3+0.8⋅1+0⋅110.1⋅3+0.3⋅1+0.6⋅111.0⋅4+0⋅2+0⋅120.2⋅4+0.8⋅2+0⋅120.1⋅4+0.3⋅2+0.6⋅12
计算每个元素:
output 1 = [ 3 4 0.6 + 0.8 0.8 + 1.6 0.3 + 0.3 + 6.6 0.4 + 0.6 + 7.2 ] = [ 3 4 1.4 2.4 7.2 8.2 ] \text{output}_1 = \begin{bmatrix} 3 & 4 \\ 0.6 + 0.8 & 0.8 + 1.6 \\ 0.3 + 0.3 + 6.6 & 0.4 + 0.6 + 7.2 \end{bmatrix} = \begin{bmatrix} 3 & 4 \\ 1.4 & 2.4 \\ 7.2 & 8.2 \end{bmatrix} output1=
30.6+0.80.3+0.3+6.640.8+1.60.4+0.6+7.2
=
31.47.242.48.2
计算第二个头的加权和:
output 2 = scores 2 ⋅ [ 7 8 5 6 ] = [ 1.0 ⋅ 7 + 0 ⋅ 5 + 0 ⋅ 11 1.0 ⋅ 8 + 0 ⋅ 6 + 0 ⋅ 12 0.3 ⋅ 7 + 0.7 ⋅ 5 + 0 ⋅ 11 0.3 ⋅ 8 + 0.7 ⋅ 6 + 0 ⋅ 12 0.1 ⋅ 7 + 0.2 ⋅ 5 + 0.7 ⋅ 11 0.1 ⋅ 8 + 0.2 ⋅ 6 + 0.7 ⋅ 12 ] \text{output}_2 = \text{scores}_2 \cdot \begin{bmatrix} 7 & 8 \\ 5 & 6 \end{bmatrix} = \begin{bmatrix} 1.0 \cdot 7 + 0 \cdot 5 + 0 \cdot 11 & 1.0 \cdot 8 + 0 \cdot 6 + 0 \cdot 12 \\ 0.3 \cdot 7 + 0.7 \cdot 5 + 0 \cdot 11 & 0.3 \cdot 8 + 0.7 \cdot 6 + 0 \cdot 12 \\ 0.1 \cdot 7 + 0.2 \cdot 5 + 0.7 \cdot 11 & 0.1 \cdot 8 + 0.2 \cdot 6 + 0.7 \cdot 12 \end{bmatrix} output2=scores2⋅[7586]=
1.0⋅7+0⋅5+0⋅110.3⋅7+0.7⋅5+0⋅110.1⋅7+0.2⋅5+0.7⋅111.0⋅8+0⋅6+0⋅120.3⋅8+0.7⋅6+0⋅120.1⋅8+0.2⋅6+0.7⋅12
计算每个元素:
output 2 = [ 7 8 2.1 + 3.5 2.4 + 4.2 0.7 + 1.0 + 7.7 0.8 + 1.2 + 8.4 ] = [ 7 8 5.6 6.6 9.4 10.4 ] \text{output}_2 = \begin{bmatrix} 7 & 8 \\ 2.1 + 3.5 & 2.4 + 4.2 \\ 0.7 + 1.0 + 7.7 & 0.8 + 1.2 + 8.4 \end{bmatrix} = \begin{bmatrix} 7 & 8 \\ 5.6 & 6.6 \\ 9.4 & 10.4 \end{bmatrix} output2=
72.1+3.50.7+1.0+7.782.4+4.20.8+1.2+8.4
=
75.69.486.610.4
6. 恢复维度并投影
恢复时间维度并合并头
将头作为批次维度处理后,恢复时间维度并合并头:
output = [ output 1 output 2 ] . t r a n s p o s e ( 1 , 2 ) . c o n t i g u o u s ( ) . v i e w ( 1 , 3 , 4 ) \text{output} = \begin{bmatrix} \text{output}_1 & \text{output}_2 \end{bmatrix}.transpose(1, 2).contiguous().view(1, 3, 4) output=[output1output2].transpose(1,2).contiguous().view(1,3,4)
合并后的矩阵为:
output = [ 3 4 7 8 1.4 2.4 5.6 6.6 7.2 8.2 9.4 10.4 ] \text{output} = \begin{bmatrix} 3 & 4 & 7 & 8 \\ 1.4 & 2.4 & 5.6 & 6.6 \\ 7.2 & 8.2 & 9.4 & 10.4 \end{bmatrix} output=
31.47.242.48.275.69.486.610.4
最终投影回残差流
假设输出权重矩阵 w o wo wo为:
w o = [ 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 ] wo = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} wo=
1000010000100001
最终输出:
output = w o ( output ) = [ 3 4 7 8 1.4 2.4 5.6 6.6 7.2 8.2 9.4 10.4 ] \text{output} = wo(\text{output}) = \begin{bmatrix} 3 & 4 & 7 & 8 \\ 1.4 & 2.4 & 5.6 & 6.6 \\ 7.2 & 8.2 & 9.4 & 10.4 \end{bmatrix} output=wo(output)=
31.47.242.48.275.69.486.610.4
假设 Dropout 概率为 0.1,随机生成一个与输出矩阵形状相同的掩码矩阵。为了简化,我们假设掩码矩阵如下(随机生成,但符合 90% 保留的概率):
dropout_mask = [ 1 1 1 1 1 1 1 1 1 1 1 1 ] \text{dropout\_mask} = \begin{bmatrix} 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 \end{bmatrix} dropout_mask=
111111111111
应用 Dropout:
output = output × dropout_mask = [ 3 4 7 8 1.4 2.4 5.6 6.6 7.2 8.2 9.4 10.4 ] \text{output} = \text{output} \times \text{dropout\_mask} = \begin{bmatrix} 3 & 4 & 7 & 8 \\ 1.4 & 2.4 & 5.6 & 6.6 \\ 7.2 & 8.2 & 9.4 & 10.4 \end{bmatrix} output=output×dropout_mask=
31.47.242.48.275.69.486.610.4
最终输出
最终输出矩阵 output \text{output} output为:
output = [ 3 4 7 8 1.4 2.4 5.6 6.6 7.2 8.2 9.4 10.4 ] \text{output} = \begin{bmatrix} 3 & 4 & 7 & 8 \\ 1.4 & 2.4 & 5.6 & 6.6 \\ 7.2 & 8.2 & 9.4 & 10.4 \end{bmatrix} output=
31.47.242.48.275.69.486.610.4
这个矩阵就是经过自注意力机制处理后的结果。
总结
通过上述步骤,我们完成了自注意力机制的计算过程,从输入矩阵 x x x到最终的输出矩阵 output \text{output} output。这个过程体现了自注意力机制的核心思想:通过计算输入序列内部的依赖关系,生成加权后的特征表示。
七、什么是时间维度
好的,我来详细解释一下“时间维度”在自注意力机制中的含义,以及它在代码中的体现。
时间维度的含义
在自然语言处理(NLP)任务中,输入数据通常是序列化的,例如文本句子、时间序列数据等。每个序列可以看作是在时间轴上的一系列元素。因此,序列中的每个元素(如单词、字符或时间点)可以被视为一个时间步(time step)。这种时间步的排列就构成了“时间维度”。
以文本处理为例:
- 假设输入是一个句子:“Hello world!”。
- 这个句子可以被表示为一个序列:
["Hello", "world", "!"]
。 - 每个单词(或标记)对应一个时间步,整个序列的长度(即单词数量)就是时间维度的大小。
在矩阵表示中,时间维度通常对应于矩阵的第二个维度(假设第一个维度是批次维度)。例如,对于输入矩阵 x x x的形状为 [batch_size, seq_len, dim]
,其中:
batch_size
是批次大小,表示一次处理的序列数量。seq_len
是序列长度,即时间维度的大小。dim
是每个时间步的特征维度。
时间维度在代码中的体现
在代码中,时间维度主要体现在以下几个方面:
1. 输入矩阵的形状
输入矩阵 x x x的形状为 [batch_size, seq_len, dim]
。其中:
seq_len
就是时间维度的大小,表示序列的长度。- 每个时间步对应一个向量,向量的维度为
dim
。
以代码中的输入矩阵为例:
x = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] x = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} x=
159261037114812
- 假设
batch_size = 1
,seq_len = 3
,dim = 4
。 - 这个矩阵表示一个长度为 3 的序列,每个时间步的特征维度为 4。
2. 时间维度的调整
在自注意力机制中,时间维度可能会在计算过程中被调整或重新排列。例如:
- 在计算查询(Q)、键(K)、值(V)时,时间维度保持不变:
x q = w q ( x ) 形状为 [ b a t c h s i z e , s e q l e n , n l o c a l h e a d s × h e a d d i m ] xq = wq(x) \quad \text{形状为} \quad [batch_size, seq_len, n_local_heads \times head_dim] xq=wq(x)形状为[batchsize,seqlen,nlocalheads×headdim]
x k = w k ( x ) 形状为 [ b a t c h s i z e , s e q l e n , n l o c a l k v h e a d s × h e a d d i m ] xk = wk(x) \quad \text{形状为} \quad [batch_size, seq_len, n_local_kv_heads \times head_dim] xk=wk(x)形状为[batchsize,seqlen,nlocalkvheads×headdim]
x v = w v ( x ) 形状为 [ b a t c h s i z e , s e q l e n , n l o c a l k v h e a d s × h e a d d i m ] xv = wv(x) \quad \text{形状为} \quad [batch_size, seq_len, n_local_kv_heads \times head_dim] xv=wv(x)形状为[batchsize,seqlen,nlocalkvheads×headdim]
- 在将头作为批次维度处理时,时间维度会被调整到第二个位置:
x q = x q . t r a n s p o s e ( 1 , 2 ) 形状为 [ b a t c h s i z e , n l o c a l h e a d s , s e q l e n , h e a d d i m ] xq = xq.transpose(1, 2) \quad \text{形状为} \quad [batch_size, n_local_heads, seq_len, head_dim] xq=xq.transpose(1,2)形状为[batchsize,nlocalheads,seqlen,headdim]
x k = x k . t r a n s p o s e ( 1 , 2 ) 形状为 [ b a t c h s i z e , n l o c a l k v h e a d s , s e q l e n , h e a d d i m ] xk = xk.transpose(1, 2) \quad \text{形状为} \quad [batch_size, n_local_kv_heads, seq_len, head_dim] xk=xk.transpose(1,2)形状为[batchsize,nlocalkvheads,seqlen,headdim]
x v = x v . t r a n s p o s e ( 1 , 2 ) 形状为 [ b a t c h s i z e , n l o c a l k v h e a d s , s e q l e n , h e a d d i m ] xv = xv.transpose(1, 2) \quad \text{形状为} \quad [batch_size, n_local_kv_heads, seq_len, head_dim] xv=xv.transpose(1,2)形状为[batchsize,nlocalkvheads,seqlen,headdim]
3. 时间维度的恢复
在完成注意力计算后,时间维度需要恢复到原始位置。例如:
- 在 Flash Attention 或手动实现的注意力机制中,计算结果的形状为
[batch_size, n_local_heads, seq_len, head_dim]
。 - 通过
transpose
和view
操作恢复时间维度:
o u t p u t = o u t p u t . t r a n s p o s e ( 1 , 2 ) . c o n t i g u o u s ( ) . v i e w ( b s z , s e q l e n , − 1 ) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output=output.transpose(1,2).contiguous().view(bsz,seqlen,−1)
transpose(1, 2)
将头的维度和时间维度交换,形状变为[batch_size, seq_lenn_local_heads, head_dim]
。view(bsz, seqlen, -1)
将头的维度和特征维度合并,形状变为[batch_size, seq_len, dim]
。
4. 因果遮蔽
在自注意力机制中,因果遮蔽(Causal Masking)用于确保模型在计算某个时间步的输出时,只能看到该时间步之前的信息。这在语言模型中非常重要,因为语言模型需要生成符合语法和语义的序列。
因果遮蔽矩阵的形状为 [1, 1, seq_len, seq_len]
,例如:
mask = [ 0 − ∞ − ∞ 0 0 − ∞ 0 0 0 ] \text{mask} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix} mask=
000−∞00−∞−∞0
- 对角线上的值为 0,表示当前时间步可以关注自身。
- 上三角部分的值为 − ∞ -\infty −∞,表示当前时间步不能关注未来的时间步。
在代码中,因果遮蔽矩阵被添加到注意力分数中:
scores = scores + mask [ : , : , : s e q l e n , : s e q l e n ] \text{scores} = \text{scores} + \text{mask}[:, :, :seqlen, :seqlen] scores=scores+mask[:,:,:seqlen,:seqlen]
这样,未来时间步的注意力分数会被设置为 − ∞ -\infty −∞,在 Softmax 后权重趋近于零,从而实现因果遮蔽。
6.总结
时间维度在自注意力机制中起着关键作用:
- 它表示序列的长度,即时间步的数量。
- 在计算过程中,时间维度可能会被调整或重新排列,以适应不同阶段的计算需求。
- 因果遮蔽机制通过时间维度来确保模型只能关注当前及之前的时间步,从而生成符合逻辑的序列。
通过理解时间维度的作用,我们可以更好地理解自注意力机制的实现和优化。
八、MLP(多层感知机,Multilayer Perceptron)
class MLP(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
# 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍
# 然后将其减少到2/3,最后确保它是multiple_of的倍数
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义第三层线性变换,从输入维度到隐藏维度
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
# 定义dropout层,用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先,输入x通过第一层线性变换和SILU激活函数
# 然后,结果乘以输入x通过第三层线性变换的结果
# 最后,通过第二层线性变换和dropout层
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
MLP(多层感知机,Multilayer Perceptron)是一种基础的前馈神经网络,它由多个层组成,每层由多个神经元构成。MLP 通常包含以下几层:
- 输入层(Input Layer):接收原始数据输入。
- 隐藏层(Hidden Layers):一个或多个中间层,每层都使用激活函数来引入非线性,从而让网络能够学习和执行更复杂的任务。
- 输出层(Output Layer):产生最终的预测结果。
MLP 结构
MLP 的结构可以用以下公式表示:
- 对于输入层到第一个隐藏层的转换:
h ( 1 ) = f ( W ( 1 ) x + b ( 1 ) ) h^{(1)} = f\left(W^{(1)}x + b^{(1)}\right) h(1)=f(W(1)x+b(1))
其中, x x x是输入向量, W ( 1 ) W^{(1)} W(1)是第一层的权重矩阵, b ( 1 ) b^{(1)} b(1)是第一层的偏置向量, f f f是激活函数, h ( 1 ) h^{(1)} h(1)是第一层的输出。
- 对于第一个隐藏层到第二个隐藏层的转换(如果有多个隐藏层):
h ( 2 ) = f ( W ( 2 ) h ( 1 ) + b ( 2 ) ) h^{(2)} = f\left(W^{(2)}h^{(1)} + b^{(2)}\right) h(2)=f(W(2)h(1)+b(2))
其中, h ( 1 ) h^{(1)} h(1)是第一层的输出, W ( 2 ) W^{(2)} W(2)是第二层的权重矩阵, b ( 2 ) b^{(2)} b(2)是第二层的偏置向量, h ( 2 ) h^{(2)} h(2)是第二层的输出。
- 对于最后一个隐藏层到输出层的转换:
y = W ( L ) h ( L − 1 ) + b ( L ) y = W^{(L)}h^{(L-1)} + b^{(L)} y=W(L)h(L−1)+b(L)
其中, h ( L − 1 ) h^{(L-1)} h(L−1)是最后一个隐藏层的输出, W ( L ) W^{(L)} W(L)是输出层的权重矩阵, b ( L ) b^{(L)} b(L)是输出层的偏置向量, y y y是最终的输出向量。
激活函数
激活函数 f f f是引入非线性的关键,常见的激活函数包括:
- Sigmoid: f ( x ) = 1 1 + e − x f(x) = \frac{1}{1 + e^{-x}} f(x)=1+e−x1
- Tanh: f ( x ) = e x − e − x e x + e − x f(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} f(x)=ex+e−xex−e−x
- ReLU(Rectified Linear Unit): f ( x ) = max ( 0 , x ) f(x) = \max(0, x) f(x)=max(0,x)
- Leaky ReLU: f ( x ) = max ( 0.01 x , x ) f(x) = \max(0.01x, x) f(x)=max(0.01x,x)
SiLU(Sigmoid Linear Unit) 激活函数,也被称为Swish激活函数,是一种自适应激活函数,最早由Google Brain在2017年引入。SiLU激活函数的定义如下:
SiLU ( x ) = x ⋅ σ ( x ) \text{SiLU}(x) = x \cdot \sigma(x) SiLU(x)=x⋅σ(x)
其中, σ ( x ) \sigma(x) σ(x)是标准的Sigmoid函数,它的值在0和1之间。SiLu函数的特点包括非线性、连续可导,并且在负无穷到正无穷的范围内都有定义。SiLU函数可以看作是平滑的ReLU激活函数。它既有ReLU激活函数的一些优点(例如能够缓解梯度消失问题),又能解决ReLU函数的一些缺点(例如ReLU函数不是零中心的,且在负数部分的梯度为零)。此外,SiLu函数还是平滑函数,这意味着它在整个定义域内都有导数,有利于优化。
Sigmoid函数是常见的激活函数,其表达式为:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+e−x1
SiLU 激活函数可以看作是 sigmoid 函数和输入值的乘积。它将输入值乘以一个介于 0 和 1 之间的值,从而实现非线性激活。SiLU 激活函数具有以下几个重要性质:
- 平滑性: SiLU 激活函数是平滑的,这意味着它的导数在所有点都存在。这使得 SiLU 激活函数更易于优化,并有助于避免梯度消失问题。
- 非单调性: SiLU 激活函数是非单调的,这意味着它在某些区间内是单调递增的,而在其他区间内是单调递减的。这使得 SiLU 激活函数可以更好地学习复杂的数据模式。
- 零中心性: SiLU 激活函数在零点处取值为零。这使得 SiLU 激活函数可以更好地处理输入数据的分布,并避免梯度爆炸问题。
DecoderLayer解码器层的实现
class DecoderLayer(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
# 定义多头注意力的头数
self.n_heads = args.n_heads
# 定义输入维度
self.dim = args.dim
# 定义每个头的维度,等于输入维度除以头数
self.head_dim = args.dim // args.n_heads
# 定义LLaMA2Attention对象,用于进行多头注意力计算
self.attention = Attention(args)
# 定义LLaMAMLP对象,用于进行前馈神经网络计算
self.feed_forward = MLP(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
# 定义层的ID
self.layer_id = layer_id
# 定义注意力计算的归一化层
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
# 定义前馈神经网络计算的归一化层
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin):
# 前向传播函数
# 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h
# 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
九、transformer总架构
class Transformer(nn.Module):
last_loss: Optional[torch.Tensor]
def __init__(self, args: ModelArgs):
super().__init__()
# 初始化模型参数
self.args = args
# 词汇表大小
self.vocab_size = args.vocab_size
# 层数
self.n_layers = args.n_layers
# 词嵌入层
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
# Dropout层
self.dropout = nn.Dropout(args.dropout)
# Decoder层
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(DecoderLayer(layer_id, args))
# 归一化层
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
# 输出层
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
# 将词嵌入层的权重与输出层的权重共享
self.tok_embeddings.weight = self.output.weight
# 预计算相对位置嵌入的频率
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
# 初始化所有权重
self.apply(self._init_weights)
# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
# 初始化最后一次前向传播的损失属性
self.last_loss = None
def _init_weights(self, module):
# 初始化权重的函数
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
# 前向传播函数
_bsz, seqlen = tokens.shape
# 通过词嵌入层和Dropout层
h = self.tok_embeddings(tokens)
h = self.dropout(h)
# 获取相对位置嵌入的频率
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
# 通过Decoder层
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)
# 通过归一化层
h = self.norm(h)
if targets is not None:
# 如果给定了目标,计算损失
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# 推理时的小优化:只对最后一个位置的输出进行前向传播
logits = self.output(h[:, [-1], :])
self.last_loss = None
return logits
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# 获取所有需要更新的参数
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
# 将参数分为需要权重衰减和不需要权重衰减的两组
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
# 打印参数数量信息
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# 根据设备类型选择使用标准 AdamW 或其融合版本
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
def estimate_mfu(self, fwdbwd_per_iter, dt):
""" 估计模型的 FLOPs 利用率 (MFU) 单位:A100 bfloat16 的峰值 FLOPS """
# 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B)
# PaLM: Scaling Language Modeling with Pathways: https://arxiv.org/abs/2204.02311
N = sum(p.numel() for p in self.parameters())
cfg = self.args
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例
flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs
flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
给定输入序列 idx(形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
在 model.eval() 模式下运行。效率较低的采样版本,没有使用键k/v cache。
"""
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond)
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
if temperature == 0.0:
# 选择最有可能的索引
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# 缩放 logits 并应用 softmax
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# 将采样的索引添加到序列中并继续
idx = torch.cat((idx, idx_next), dim=1)
return idx
这段代码定义了一个名为 Transformer
的类,它是一个神经网络模型,继承自 PyTorch 的 nn.Module
类。这个模型看起来是为了处理自然语言处理(NLP)任务,特别是文本生成任务,如机器翻译或文本摘要。以下是代码中关键部分的解释:
初始化方法
__init__
:这个方法设置了模型的主要参数,包括词汇表大小(vocab_size
)、层数(n_layers
)、嵌入层(tok_embeddings
)、Dropout层(dropout
)、解码器层(layers
)、归一化层(norm
)和输出层(output
)。它还注册了一些缓冲区(buffers)来保存相对位置编码的频率,并对模型的权重进行初始化。初始化权重
_init_weights
:这是一个辅助函数,用于初始化模型中的权重。前向传播方法
forward
:这个方法定义了模型的前向传播逻辑。它接受输入tokens
和目标targets
,通过嵌入层和Dropout层,然后通过多个解码器层,最后通过输出层生成输出。如果提供了目标targets
,则计算交叉熵损失。配置优化器方法
configure_optimizers
:这个方法用于设置优化器。它将模型参数分为两组:需要权重衰减(weight decay)和不需要权重衰减的参数,并创建一个 AdamW 优化器。估计模型FLOPs利用率
estimate_mfu
:这个方法用于估计模型的计算效率,以 A100 bfloat16 峰值 FLOPS 的比例表示。生成方法
generate
:这个方法用于生成文本。它接受一个初始序列idx
,生成新 token 的最大数量max_new_tokens
,以及控制生成过程的温度参数temperature
和采样策略top_k
。这个方法在模型处于推理模式(torch.inference_mode()
)下运行。
整体来说,这个 Transformer
类实现了一个基于 Transformer 架构的序列到序列模型,它可以用于多种 NLP 任务。代码中包含了模型的定义、权重初始化、前向传播逻辑、优化器配置以及文本生成逻辑。