I. Overall Transformer Architecture
The Transformer is built from stacks of Encoder and Decoder layers. Each layer contains the following core sublayers (a reference sketch using PyTorch's built-in `nn.Transformer` follows this list):
- Encoder layer: Multi-Head Self-Attention → Add & LayerNorm → Feed-Forward → Add & LayerNorm
- Decoder layer: Masked Multi-Head Self-Attention → Add & LayerNorm → Cross-Attention (Encoder-Decoder) → Add & LayerNorm → Feed-Forward → Add & LayerNorm
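As a point of reference, the same layer layout is available out of the box as PyTorch's `nn.Transformer`. The snippet below is a sketch only: the shapes and hyperparameters are arbitrary illustration values, and embeddings, positional encodings, and masks must be supplied by the caller.

```python
import torch
import torch.nn as nn

# Built-in module with the same sublayer ordering (Post-LN by default).
ref = nn.Transformer(d_model=512, nhead=8,
                     num_encoder_layers=6, num_decoder_layers=6,
                     dim_feedforward=2048, dropout=0.1,
                     batch_first=True)        # tensors laid out as (B, seq_len, d_model)

src = torch.randn(2, 10, 512)  # already-embedded source: (B, src_len, d_model)
tgt = torch.randn(2, 7, 512)   # already-embedded target: (B, tgt_len, d_model)
print(ref(src, tgt).shape)     # torch.Size([2, 7, 512])
```

The rest of this section builds the same architecture from scratch, component by component.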
II. Core Components: Code and Commentary
1. Positional Encoding
Principle: the position of each token in the sequence is encoded as a vector and added to the word embedding before it enters the model. Common choices are fixed sine/cosine functions or learned parameters.
```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)                 # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)   # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)       # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)       # odd dimensions
        self.register_buffer('pe', pe)                     # not a learnable parameter

    def forward(self, x):
        # x: (B, seq_len, d_model)
        x = x + self.pe[:x.size(1), :]                     # take only the first seq_len positions
        return x
```
Code commentary:
- Fixed-pattern encoding: sine/cosine functions generate the encodings, with even and odd dimensions handled separately.
- Extensibility: `max_len` sets the maximum supported sequence length, so one precomputed buffer covers inputs of different lengths.
- Additive fusion: the positional encoding is added directly to the word embedding, preserving position information (a quick shape check follows this list).
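A quick shape check of the module above; the batch size and sequence length are arbitrary illustration values.

```python
import torch

pos_enc = PositionalEncoding(d_model=512, max_len=5000)
x = torch.zeros(4, 20, 512)            # (B, seq_len, d_model) dummy embeddings
out = pos_enc(x)
print(out.shape)                       # torch.Size([4, 20, 512])
print(torch.allclose(out[0], out[1]))  # True: the same encoding is added to every batch element
```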
2. Layer Normalization
Principle: normalization is applied over the feature dimension of each sample (rather than over the batch dimension), which stabilizes training.

```python
# PyTorch built-in implementation
layer_norm = nn.LayerNorm(d_model)
```

Key properties:
- Normalized dimension: mean and variance are computed along the feature dimension (`d_model`).
- Residual connection: typically used together with a residual structure such as `x + sublayer(x)`; see the worked comparison below.
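A minimal sketch comparing `nn.LayerNorm` with a manual per-token normalization (using the layer's own `eps` and affine parameters, which default to 1 and 0):

```python
import torch
import torch.nn as nn

d_model = 8
x = torch.randn(2, 5, d_model)                  # (B, seq_len, d_model)

layer_norm = nn.LayerNorm(d_model)
out = layer_norm(x)

# Manual equivalent: normalize each token over its d_model features.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)
manual = manual * layer_norm.weight + layer_norm.bias

print(torch.allclose(out, manual, atol=1e-6))   # True
```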
3. Feed-Forward Layer
Principle: two linear transformations with a non-linear activation in between, increasing the model's expressive power.

```python
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        x = self.linear1(x)    # (B, seq_len, d_model) -> (B, seq_len, dim_feedforward)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)    # back to (B, seq_len, d_model)
        return x
```
Code commentary:
- Dimension expansion: the input is first projected up (`d_model → dim_feedforward`) and then back down, enriching feature interactions.
- Activation: ReLU introduces the non-linearity (a shape and parameter check follows this list).
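A quick sanity check that the layer preserves the model dimension, plus the parameter count implied by the expand-then-project structure (toy shapes for illustration):

```python
import torch

ffn = FeedForward(d_model=512, dim_feedforward=2048, dropout=0.1)
x = torch.randn(4, 20, 512)                       # (B, seq_len, d_model)
print(ffn(x).shape)                               # torch.Size([4, 20, 512])

# linear1: 512*2048 + 2048, linear2: 2048*512 + 512  ->  2,099,712 parameters in total
print(sum(p.numel() for p in ffn.parameters()))   # 2099712
```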
4. Encoder Layer

```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # MultiHeadAttention(d_model, nhead, d_k, d_v) is assumed to be a custom
        # multi-head attention module (not shown in this section); the call below
        # returns a tuple whose first element is the attended output.
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model // nhead, d_model // nhead)
        self.ffn = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # 1. Self-Attention
        attn_output, _ = self.self_attn(src, src, src, mask=src_mask)
        src = src + self.dropout(attn_output)   # residual connection
        src = self.norm1(src)
        # 2. Feed-Forward
        ffn_output = self.ffn(src)
        src = src + self.dropout(ffn_output)    # residual connection
        src = self.norm2(src)
        return src
```
Code commentary:
- Residual connection: `x + dropout(sublayer(x))` mitigates vanishing gradients.
- Normalization order: the residual is added first and the result is then normalized (Post-LN); a Pre-LN variant is sketched below.
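For comparison, many recent implementations move the LayerNorm in front of each sublayer (Pre-LN), which tends to be easier to train for deep stacks. Below is a hedged sketch of the same encoder layer in Pre-LN form, under the same assumed `MultiHeadAttention` interface as above; it is an illustration, not part of the original model.

```python
class PreLNEncoderLayer(nn.Module):
    """Pre-LN variant: normalize the input of each sublayer instead of its output."""
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model // nhead, d_model // nhead)
        self.ffn = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # LayerNorm is applied before each sublayer; the residual path stays un-normalized.
        x = self.norm1(src)
        attn_output, _ = self.self_attn(x, x, x, mask=src_mask)
        src = src + self.dropout(attn_output)
        src = src + self.dropout(self.ffn(self.norm2(src)))
        return src
```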
5. Decoder Layer

```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model // nhead, d_model // nhead)
        self.cross_attn = MultiHeadAttention(d_model, nhead, d_model // nhead, d_model // nhead)
        self.ffn = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # 1. Masked Self-Attention
        attn_output, _ = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output)
        tgt = self.norm1(tgt)
        # 2. Encoder-Decoder Cross-Attention
        attn_output, _ = self.cross_attn(tgt, memory, memory, mask=memory_mask)
        tgt = tgt + self.dropout(attn_output)
        tgt = self.norm2(tgt)
        # 3. Feed-Forward
        ffn_output = self.ffn(tgt)
        tgt = tgt + self.dropout(ffn_output)
        tgt = self.norm3(tgt)
        return tgt
```
Code commentary:
- Masked Self-Attention: `tgt_mask` blocks attention to future positions via a causal (triangular) mask; a construction sketch follows this list.
- Cross-Attention: `tgt` serves as the Query, while `memory` (the Encoder output) provides the Key and Value.
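A minimal sketch of how such a causal mask can be built with `torch.tril`. The exact convention (boolean keep/block versus additive `-inf`) depends on how `MultiHeadAttention` consumes the mask, so treat the convention below as an assumption: here `True`/`1` marks positions that may be attended to.

```python
import torch

def generate_causal_mask(seq_len):
    # Lower-triangular matrix: position i may attend to positions 0..i only.
    # Assumed convention: 1 = attend, 0 = block.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(generate_causal_mask(4).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```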
6. Complete Transformer Model

```python
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # 1. Embeddings and positional encoding
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # 2. Stacks of Encoder and Decoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_decoder_layers)
        ])
        # 3. Output projection
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder
        src_emb = self.pos_encoder(self.src_embed(src))
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask)
        # Decoder (memory_mask is left as None here)
        tgt_emb = self.pos_encoder(self.tgt_embed(tgt))
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, src_emb, tgt_mask, None)
        # Output logits over the target vocabulary
        output = self.fc_out(tgt_emb)
        return output
```
Key parameters:
- `src_vocab_size`: source-language vocabulary size.
- `tgt_vocab_size`: target-language vocabulary size.
- `d_model`: model dimension (usually equal to the embedding dimension).
- `nhead`: number of attention heads.
- `num_encoder_layers` / `num_decoder_layers`: number of Encoder / Decoder layers.
III. Usage Example

```python
import torch

# Hyperparameters
src_vocab_size = 10000   # source-language vocabulary size
tgt_vocab_size = 8000    # target-language vocabulary size
d_model = 512
nhead = 8
num_layers = 6
batch_size = 32
seq_len = 50

# Instantiate the model
model = Transformer(src_vocab_size, tgt_vocab_size, d_model, nhead, num_layers, num_layers)

# Random inputs (token ids)
src = torch.randint(0, src_vocab_size, (batch_size, seq_len))   # (B, seq_len)
tgt = torch.randint(0, tgt_vocab_size, (batch_size, seq_len))   # (B, seq_len)

# Forward pass
output = model(src, tgt)
print("Output shape:", output.shape)   # (B, seq_len, tgt_vocab_size)
```