This article uses the PyTorch library to implement by hand the submodules of the classic Transformer model (multi-head self-attention, residual connection with layer normalization, the feed-forward layer, the encoder, and the decoder), and then assembles them into a complete Transformer.
"""
@Title: 解析 Transformer
@Time: 2025/5/10
@Author: Michael Jie
"""
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor


# Scaled Dot-Product Attention
class Attention(nn.Module):
    def __init__(self, causal: bool = True) -> None:
        """
        Attention formula: Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V

        Args:
            causal: whether to auto-generate a causal mask, defaults to True
        """
        super(Attention, self).__init__()
        self.causal = causal

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> tuple[Tensor, Tensor]:
        """
        Padding mask: handles variable-length sequences so that padded positions do not affect the attention computation.
        Causal mask: prevents the decoder from seeing future positions during training.

        Args:
            q: queries, shape=(..., seq_len_q, d_k)
            k: keys, shape=(..., seq_len_k, d_k)
            v: values, shape=(..., seq_len_k, d_v)
            padding_mask: padding mask, shape=(..., seq_len_k)
            attn_mask: causal mask, shape=(..., seq_len_q, seq_len_k)

        Returns:
            output: output, shape=(..., seq_len_q, d_v)
            weights: attention weights, shape=(..., seq_len_q, seq_len_k)
        """
        # Attention scores
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply the padding mask
        if padding_mask is not None:
            # Broadcast to (..., 1, seq_len_k)
            scores = scores.masked_fill(padding_mask.unsqueeze(-2), float("-inf"))
        # Auto-generate the causal mask; a user-supplied mask takes precedence
        seq_len_q, seq_len_k = q.size(-2), k.size(-2)
        if self.causal and attn_mask is None:
            attn_mask = torch.triu(
                torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=scores.device),
                diagonal=1,
            )
        # Apply the causal mask
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        # Attention weights
        weights = F.softmax(scores, dim=-1)
        # Re-apply the padding mask so padded positions end up with exactly zero weight
        if padding_mask is not None:
            weights = weights.masked_fill(padding_mask.unsqueeze(-2), 0)
        # Multiply by V to get the output
        output = torch.matmul(weights, v)
        return output, weights
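

# Quick sanity check (a minimal sketch, not part of the original article): for the unmasked,
# non-causal case the module above should agree with PyTorch's built-in
# F.scaled_dot_product_attention, assuming PyTorch >= 2.0 where that function exists.
def _check_against_builtin_sdpa() -> None:
    q = torch.rand(2, 3, 64)
    k = torch.rand(2, 5, 64)
    v = torch.rand(2, 5, 64)
    out_custom, _ = Attention(causal=False)(q, k, v)
    # Same math as above: softmax(Q · K^T / sqrt(d_k)) · V
    out_builtin = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(out_custom, out_builtin, atol=1e-5)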


# Self-Attention
class SelfAttention(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        Self-attention is a special form of attention in which Q, K and V all come from the same
        input sequence. It captures relationships between elements within the sequence and models
        long-range dependencies directly, without relying on an RNN or CNN.

        Args:
            d_model: feature dimension, defaults to 512
        """
        super(SelfAttention, self).__init__()
        self.attention = Attention()  # attention mechanism
        # A single linear projection that produces Q, K and V at once
        self.linear_qkv = nn.Linear(d_model, d_model * 3)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        Args:
            x: token embeddings, shape=(batch_size, seq_len, d_model)
            padding_mask: padding mask, shape=(batch_size, seq_len)
            attn_mask: causal mask, shape=(seq_len, seq_len)

        Returns:
            output: output, shape=(batch_size, seq_len, d_model)
        """
        # Generate Q, K and V with a single linear layer
        qkv = self.linear_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)  # (batch_size, seq_len, d_model)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)
        return self.linear_out(output)
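

# Illustration (a minimal sketch, not part of the original article): the fused qkv projection is
# equivalent to three separate d_model x d_model projections whose weights are the corresponding
# row slices of linear_qkv. The helper name below is hypothetical.
def _check_fused_qkv_split(d_model: int = 512) -> None:
    module = SelfAttention(d_model)
    x = torch.rand(2, 5, d_model)
    q_fused = module.linear_qkv(x).chunk(3, dim=-1)[0]
    # The first d_model output features use the first d_model rows of the fused weight matrix
    q_separate = F.linear(x, module.linear_qkv.weight[:d_model], module.linear_qkv.bias[:d_model])
    assert torch.allclose(q_fused, q_separate, atol=1e-6)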


# Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int = 512, num_heads: int = 8, causal: bool = True) -> None:
        """
        Multi-head self-attention extends self-attention by splitting the input features into
        several heads, computing attention independently in each head, and concatenating the
        results, which lets the model attend to the sequence from multiple representation subspaces.

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            causal: whether to auto-generate a causal mask, defaults to True
        """
        super(MultiHeadSelfAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, but got {d_model} and {num_heads}")
        self.num_heads = num_heads
        self.attention = Attention(causal)  # attention mechanism
        # Separate linear projections for Q, K and V
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        Q, K and V may come from different places depending on where this module is used:
        in encoder-decoder (cross) attention, Q comes from the decoder input while K and V
        come from the encoder output.

        Args:
            q: queries, shape=(batch_size, seq_len, d_model)
            k: keys, shape=(batch_size, seq_len or seq_len_k, d_model)
            v: values, shape=(batch_size, seq_len or seq_len_k, d_model)
            padding_mask: padding mask, shape=(batch_size, seq_len or seq_len_k)
            attn_mask: causal mask, shape=(seq_len, seq_len or seq_len_k)

        Returns:
            output: output, shape=(batch_size, seq_len, d_model)
        """
        q = self.linear_q(q)
        k = self.linear_k(k)
        v = self.linear_v(v)
        batch_size, seq_len, seq_len_k = q.size(0), q.size(1), k.size(1)
        # (batch_size, num_heads, seq_len, d_k)
        q = q.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        k = k.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
        v = v.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
        # Reshape the masks so they broadcast across heads
        if padding_mask is not None:
            padding_mask = padding_mask.unsqueeze(1)  # (batch_size, 1, seq_len_k)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)  # (1, seq_len, seq_len_k)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)
        # Concatenate the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.linear_out(output)
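

# Illustration (a minimal sketch, not part of the original article): splitting the feature
# dimension into heads and merging it back is a lossless round trip, which is why the
# concatenation step recovers a (batch_size, seq_len, d_model) tensor.
def _check_head_split_round_trip(batch_size: int = 2, seq_len: int = 5,
                                 num_heads: int = 8, d_model: int = 512) -> None:
    x = torch.rand(batch_size, seq_len, d_model)
    heads = x.view(batch_size, seq_len, num_heads, -1).transpose(1, 2)           # split: (b, h, s, d_k)
    merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)    # merge back: (b, s, d_model)
    assert torch.equal(x, merged)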


# Residual connection and layer normalization (Add & Norm)
class AddNorm(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        The Add & Norm layer combines two operations, a residual connection and layer
        normalization, which stabilizes training and makes it possible to build deeper
        models by stacking many such layers.

        Args:
            d_model: feature dimension, defaults to 512
        """
        super(AddNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)  # layer normalization

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return self.norm(x + y)
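

# Illustration (a minimal sketch, not part of the original article): nn.LayerNorm normalizes over
# the last (feature) dimension, so before the learnable affine parameters are trained the AddNorm
# output has roughly zero mean and unit standard deviation across its d_model features.
def _check_addnorm_statistics(d_model: int = 512) -> None:
    out = AddNorm(d_model)(torch.rand(2, 5, d_model), torch.rand(2, 5, d_model))
    assert torch.allclose(out.mean(dim=-1), torch.zeros(2, 5), atol=1e-5)
    assert torch.allclose(out.std(dim=-1, unbiased=False), torch.ones(2, 5), atol=1e-2)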


# Feed-Forward layer
class FeedForward(nn.Module):
    def __init__(self,
                 input_dim: int = 512,
                 hidden_dim: int = 2048,
                 activation: str = "relu",
                 dropout: float = 0.1) -> None:
        """
        Linear layer (expand the dimension) -> activation -> linear layer (restore the original dimension).
        The nonlinear transformation further extracts and enriches features, giving the model
        stronger pattern-recognition and semantic-composition capacity.

        Args:
            input_dim: input dimension, defaults to 512
            hidden_dim: hidden dimension, defaults to 2048
            activation: activation function, defaults to "relu"
                - supported: "sigmoid", "tanh", "relu", "gelu", "leaky_relu", "elu"
            dropout: dropout rate, defaults to 0.1
        """
        super(FeedForward, self).__init__()
        match activation:  # select the activation function
            case "sigmoid":
                activation = nn.Sigmoid()
            case "tanh":
                activation = nn.Tanh()
            case "relu":
                activation = nn.ReLU()
            case "gelu":
                activation = nn.GELU()
            case "leaky_relu":
                activation = nn.LeakyReLU()
            case "elu":
                activation = nn.ELU()
            case _:
                raise ValueError(f"Unsupported activation function: {activation}")
        # Linear -> activation -> Dropout -> Linear
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)
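

# Illustration (a minimal sketch, not part of the original article): the feed-forward network is
# applied position-wise, i.e. every position in the sequence is transformed independently with the
# same weights. Evaluation mode is used so Dropout does not add randomness to the comparison.
def _check_feedforward_is_position_wise() -> None:
    ffn = FeedForward(512, 2048).eval()
    x = torch.rand(2, 5, 512)
    with torch.no_grad():
        full = ffn(x)          # transform the whole sequence at once
        single = ffn(x[:, 2])  # transform one position on its own
    assert torch.allclose(full[:, 2], single, atol=1e-5)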


# Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            dim_feedforward: FFN hidden dimension, defaults to 2048
            dropout: dropout rate, defaults to 0.1
        """
        super(EncoderLayer, self).__init__()
        # Multi-head self-attention layer (the encoder attends bidirectionally, so no causal mask)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, causal=False)
        # Add & Norm layers
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        # Feed-forward layer
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # Dropout layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask, attn_mask)))
        x = self.norm2(x, self.dropout2(self.ffn(x)))
        return x


# Encoder
class Encoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        The encoder is a stack of encoder layers; every layer has the same structure but does
        not share parameters.

        Args:
            num_layers: number of layers, defaults to 6
            **params: encoder-layer parameters, see EncoderLayer
        """
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        for layer in self.layers:  # pass through the layers one by one
            x = layer(x, padding_mask, attn_mask)
        return x
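

# Illustration (a minimal sketch, not part of the original article): the layers have identical
# structure but separate parameters, so the encoder's parameter count is num_layers times that of
# a single encoder layer.
def _check_encoder_params_not_shared(num_layers: int = 6) -> None:
    per_layer = sum(p.numel() for p in EncoderLayer().parameters())
    total = sum(p.numel() for p in Encoder(num_layers).parameters())
    assert total == num_layers * per_layer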


# Decoder layer
class DecoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention (masked) -> AddNorm -> MultiHeadSelfAttention (cross-attention) -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            dim_feedforward: FFN hidden dimension, defaults to 2048
            dropout: dropout rate, defaults to 0.1
        """
        super(DecoderLayer, self).__init__()
        # Multi-head attention layers: masked self-attention (causal) and cross-attention (not causal)
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.cross_attn = MultiHeadSelfAttention(d_model, num_heads, causal=False)
        # Add & Norm layers
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        self.norm3 = AddNorm(d_model)
        # Feed-forward layer
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # Dropout layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask_y, attn_mask_y)))
        x = self.norm2(x, self.dropout2(self.cross_attn(x, memory, memory, padding_mask_memory, attn_mask_memory)))
        x = self.norm3(x, self.dropout3(self.ffn(x)))
        return x


# Decoder
class Decoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        The decoder is a stack of decoder layers; every layer has the same structure but does
        not share parameters.

        Args:
            num_layers: number of layers, defaults to 6
            **params: decoder-layer parameters, see DecoderLayer
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        for layer in self.layers:  # pass through the layers one by one
            x = layer(x, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)
        return x


# Transformer
class Transformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 **params) -> None:
        """
        The Transformer is the standard encoder-decoder architecture.

        Args:
            num_encoder_layers: number of encoder layers, defaults to 6
            num_decoder_layers: number of decoder layers, defaults to 6
            **params: encoder/decoder-layer parameters, see EncoderLayer and DecoderLayer
        """
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_encoder_layers, **params)  # encoder
        self.decoder = Decoder(num_decoder_layers, **params)  # decoder

    def forward(self,
                x: Tensor,
                y: Tensor,
                padding_mask_x: Tensor = None,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_x: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        memory = self.encoder(x, padding_mask_x, attn_mask_x)
        output = self.decoder(y, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)
        return output
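

# Illustration (a minimal sketch, not part of the original article): the boolean padding masks
# used throughout this file mark padded positions with True. Given the true lengths of the
# sequences in a batch, such a mask can be built as follows (make_padding_mask is a hypothetical
# helper, not part of the original code).
def make_padding_mask(lengths: Tensor, max_len: int) -> Tensor:
    # True wherever the position index is beyond the sequence's real length, i.e. padding
    return torch.arange(max_len, device=lengths.device) >= lengths.unsqueeze(-1)

# Example: make_padding_mask(torch.tensor([3, 2]), 5) yields
# [[False, False, False, True, True],
#  [False, False, True, True, True]]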


if __name__ == '__main__':
    # attention = Attention(True)
    # t1, t2 = attention(
    #     torch.rand((2, 3, 64)),
    #     torch.rand((2, 5, 64)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, True, True, True, True],
    #                   [False, False, False, False, True]])
    # )
    # print(t1.shape, t2.shape)

    # self_attention = SelfAttention()
    # t3 = self_attention(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t3.shape)

    # multi_head_self_attention = MultiHeadSelfAttention(num_heads=2)
    # t4 = multi_head_self_attention(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t4.shape)

    # encoder_layer = EncoderLayer()
    # t5 = encoder_layer(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t5.shape)

    # encoder = Encoder(dropout=0.2)
    # t6 = encoder(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t6.shape)

    # decoder_layer = DecoderLayer()
    # t7 = decoder_layer(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t7.shape)

    # decoder = Decoder()
    # t8 = decoder(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t8.shape)

    transformer = Transformer()
    t9 = transformer(
        torch.rand((2, 5, 512)),
        torch.rand((2, 3, 512)),
        torch.tensor([[False, False, False, True, True],
                      [False, False, True, True, True]]),
        torch.tensor([[False, False, False],
                      [False, False, True]]),
    )
    print(t9.shape)
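    # Expected output: torch.Size([2, 3, 512]); the result follows the decoder input's sequence length.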