在上一篇文章中,我们探讨了分布式训练实战。本文将深入解析注意力机制的完整发展历程,从最初的Seq2Seq模型到革命性的Transformer架构。我们将使用PyTorch实现2个关键阶段的注意力机制变体,并在机器翻译任务上进行对比实验。
一、注意力机制演进路线
1. 关键模型对比
模型 | 发表年份 | 核心创新 | 计算复杂度 | 典型应用 |
---|---|---|---|---|
Seq2Seq | 2014 | 编码器-解码器架构 | O(n²) | 机器翻译 |
Bahdanau Attention | 2015 | 软注意力机制(动态上下文向量) | O(n²) | 文本生成、语音识别 |
Luong Attention | 2015 | 全局/局部注意力(改进对齐方式) | O(n²) | 语音识别、长文本翻译 |
Transformer | 2017 | 自注意力机制(并行化处理) | O(n²) | 所有序列任务(NLP/CV) |
Sparse Transformer | 2019 | 稀疏注意力(分块处理长序列) | O(n√n) | 长文本生成、基因序列分析 |
MQA | 2023 | 多查询注意力(共享KV减少内存) | O(n log n) | 大模型推理加速 |
GQA | 2024 | 分组查询注意力(平衡精度与效率) | O(n log n) | 工业级大模型部署 |
Flash Attention | 2024 | 分块计算优化KV缓存 | O(n√n) | 超长序列处理(>10k tokens) |
DeepSeek MLA | 2025 | 多头潜在注意力(潜在空间投影) | O(n log n) | 多模态融合、复杂推理任务 |
TPA | 2025 | 张量积分解注意力(动态秩优化) | O(n) | 边缘计算、低资源环境 |
MoBA | 2025 | 混合块注意力(Top-K门控选择) | O(n log n) | 百万级长文本处理 |
ECA | 2025 | 高效通道注意力(参数无关门控) | O(1) | 图像分类、目标检测 |
2. 注意力类型分类
class AttentionTypes:
def __init__(self):
self.soft_attention = ["加性注意力", "点积注意力"]
self.hard_attention = ["随机硬注意力", "最大似然注意力"]
self.self_attention = ["标准自注意力", "稀疏自注意力"]
self.cross_attention = ["编码器-解码器注意力"]
二、基础注意力机制实现
1. 环境配置
pip install torch matplotlib
2. Luong注意力实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import random
from tqdm import tqdm
# 设备配置 - 检查是否有可用的GPU,没有则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据预处理部分
def build_tokenizer(text_iter, vocab_size=20000):
"""构建分词器
参数:
text_iter: 文本迭代器
vocab_size: 词汇表大小
返回:
训练好的分词器
"""
# 使用Unigram模型初始化分词器
tokenizer = Tokenizer(models.Unigram())
# 使用空格作为预分词器
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
# 配置训练器
trainer = trainers.UnigramTrainer(
vocab_size=vocab_size,
special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"], # 特殊标记
unk_token="[UNK]" # 显式设置UNK标记
)
# 从文本迭代器训练分词器
tokenizer.train_from_iterator(text_iter, trainer)
# 确保UNK标记在分词器中正确设置
if tokenizer.token_to_id("[UNK]") is None:
raise ValueError("UNK token not properly initialized in tokenizer")
return tokenizer
class TranslationDataset(Dataset):
"""翻译数据集类
用于加载和处理翻译数据
"""
def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=100):
"""初始化
参数:
data: 原始数据
src_tokenizer: 源语言分词器
trg_tokenizer: 目标语言分词器
max_len: 最大序列长度
"""
self.data = data
self.src_tokenizer = src_tokenizer
self.trg_tokenizer = trg_tokenizer
self.max_len = max_len
# 获取UNK标记的ID
self.src_unk_id = self.src_tokenizer.token_to_id("[UNK]")
self.trg_unk_id = self.trg_tokenizer.token_to_id("[UNK]")
if self.src_unk_id is None or self.trg_unk_id is None:
raise ValueError("Tokenizers must have [UNK] token")
def __len__(self):
"""返回数据集大小"""
return len(self.data)
def __getitem__(self, idx):
"""获取单个样本
参数:
idx: 索引
返回:
包含源语言和目标语言token ID的字典
"""
item = self.data[idx]["translation"]
# 源语言(中文)处理
src_encoded = self.src_tokenizer.encode(item["zh"])
# 添加开始和结束标记,并截断到最大长度
src_tokens = ["[SOS]"] + src_encoded.tokens[:self.max_len - 2] + ["[EOS]"]
# 将token转换为ID,未知token使用UNK ID
src_ids = [self.src_tokenizer.token_to_id(t) or self.src_unk_id for t in src_tokens]
# 目标语言(英文)处理
trg_encoded = self.trg_tokenizer.encode(item["en"])
trg_tokens = ["[SOS]"] + trg_encoded.tokens[:self.max_len - 2] + ["[EOS]"]
trg_ids = [self.trg_tokenizer.token_to_id(t) or self.trg_unk_id for t in trg_tokens]
return {
"src": torch.tensor(src_ids),
"trg": torch.tensor(trg_ids)
}
def collate_fn(batch):
"""批处理函数
用于DataLoader中对批次数据进行填充
"""
src = [item["src"] for item in batch]
trg = [item["trg"] for item in batch]
return {
"src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0), # 用0填充
"trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
}
# 模型实现部分
class Encoder(nn.Module):
"""编码器
将输入序列编码为隐藏状态
"""
def __init__(self, input_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1):
"""初始化
参数:
input_dim: 输入维度(词汇表大小)
emb_dim: 词嵌入维度
hid_dim: 隐藏层维度
n_layers: RNN层数
dropout: dropout率
"""
super().__init__()
self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=0) # 词嵌入层
# 双向GRU
self.rnn = nn.GRU(emb_dim, hid_dim, n_layers,
dropout=dropout if n_layers > 1 else 0,
bidirectional=True)
self.fc = nn.Linear(hid_dim * 2, hid_dim) # 用于合并双向输出的全连接层
self.dropout = nn.Dropout(dropout)
self.n_layers = n_layers
def forward(self, src):
"""前向传播
参数:
src: 输入序列
返回:
outputs: 编码器所有时间步的输出
hidden: 最后一个时间步的隐藏状态
"""
# 词嵌入 + dropout
embedded = self.dropout(self.embedding(src)) # [batch_size, src_len, emb_dim]
# GRU处理 (需要将batch维度放在第二位)
outputs, hidden = self.rnn(embedded.transpose(0, 1)) # outputs: [src_len, batch_size, hid_dim*2]
# 处理双向隐藏状态
# hidden的形状是[num_layers * num_directions, batch_size, hid_dim]
hidden = hidden.view(self.n_layers, 2, -1, self.rnn.hidden_size) # [n_layers, 2, batch_size, hid_dim]
hidden = hidden[-1] # 取最后一层 [2, batch_size, hid_dim]
hidden = torch.cat([hidden[0], hidden[1]], dim=1) # 合并双向输出 [batch_size, hid_dim*2]
hidden = torch.tanh(self.fc(hidden)) # [batch_size, hid_dim]
# 扩展以匹配解码器的层数
hidden = hidden.unsqueeze(0).repeat(self.n_layers, 1, 1) # [n_layers, batch_size, hid_dim]
return outputs, hidden
class LuongAttention(nn.Module):
def __init__(self, hid_dim, method="general"):
super().__init__()
self.method = method
if method == "general":
self.W = nn.Linear(hid_dim, hid_dim, bias=False)
elif method == "concat":
self.W = nn.Linear(hid_dim * 2, hid_dim, bias=False)
self.v = nn.Linear(hid_dim, 1, bias=False)
def forward(self, decoder_hidden, encoder_outputs):
""" decoder_hidden: [1, batch_size, hid_dim]
encoder_outputs: [src_len, batch_size, hid_dim * 2] (bidirectional)"""
if self.method == "dot":
# 处理双向输出 - 取前向和后向的平均
hid_dim = decoder_hidden.size(-1)
encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)
encoder_outputs = encoder_outputs.mean(dim=2) # [src_len, batch_size, hid_dim]
# 计算点积分数
scores = torch.matmul(
encoder_outputs.transpose(0, 1), # [batch_size, src_len, hid_dim]
decoder_hidden.transpose(0, 1).transpose(1, 2) # [batch_size, hid_dim, 1]
).squeeze(2) # [batch_size, src_len]
# 添加缩放因子
scores = scores / (decoder_hidden.size(-1) ** 0.5)
elif self.method == "general":
# 对于通用注意力,我们需要对解码器隐藏状态进行投影
decoder_hidden_proj = self.W(decoder_hidden) # [1, batch_size, hid_dim]
# 处理双向输出
hid_dim = decoder_hidden.size(-1)
encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)
encoder_outputs = encoder_outputs.mean(dim=2) # [src_len, batch_size, hid_dim]
scores = torch.matmul(
encoder_outputs.transpose(0, 1), # [batch_size, src_len, hid_dim]
decoder_hidden_proj.transpose(0, 1).transpose(1, 2) # [batch_size, hid_dim, 1]
).squeeze(2) # [batch_size, src_len]
elif self.method == "concat":
# 对于concat,我们可以使用完整的双向输出
decoder_hidden = decoder_hidden.repeat(encoder_outputs.size(0), 1, 1) # [src_len, batch_size, hid_dim]
energy = torch.cat((decoder_hidden, encoder_outputs), dim=2) # [src_len, batch_size, hid_dim*3]
scores = self.v(torch.tanh(self.W(energy))).squeeze(2).t() # [batch_size, src_len]
attn_weights = F.softmax(scores, dim=1)
# 对于上下文向量,使用原始双向输出
context = torch.bmm(
attn_weights.unsqueeze(1), # [batch_size, 1, src_len]
encoder_outputs.transpose(0, 1) # [batch_size, src_len, hid_dim*2]
).squeeze(1) # [batch_size, hid_dim*2]
return context, attn_weights
class Decoder(nn.Module):
"""解码器
使用注意力机制生成目标序列
"""
def __init__(self, output_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1, attn_method="general"):
"""初始化
参数:
output_dim: 输出维度(目标词汇表大小)
emb_dim: 词嵌入维度
hid_dim: 隐藏层维度
n_layers: RNN层数
dropout: dropout率
attn_method: 注意力计算方法
"""
super().__init__()
self.output_dim = output_dim
self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=0) # 词嵌入层
# 单向GRU
self.rnn = nn.GRU(emb_dim, hid_dim, n_layers,
dropout=dropout if n_layers > 1 else 0)
self.attention = LuongAttention(hid_dim, attn_method) # 注意力层
# 全连接层(根据注意力方法调整输入维度)
self.fc = nn.Linear(hid_dim * 3 if attn_method == "concat" else hid_dim * 2, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, input, hidden, encoder_outputs):
"""前向传播
参数:
input: 当前输入token
hidden: 当前隐藏状态
encoder_outputs: 编码器输出
返回:
prediction: 预测的下一个token
hidden: 新的隐藏状态
attn_weights: 注意力权重
"""
input = input.unsqueeze(0) # [1, batch_size]
embedded = self.dropout(self.embedding(input)) # [1, batch_size, emb_dim]
# GRU处理
output, hidden = self.rnn(embedded, hidden) # output: [1, batch_size, hid_dim]
# 计算注意力
context, attn_weights = self.attention(output, encoder_outputs)
# 预测下一个token(拼接RNN输出和上下文向量)
prediction = self.fc(torch.cat((output.squeeze(0), context), dim=1))
return prediction, hidden, attn_weights
class Seq2Seq(nn.Module):
"""序列到序列模型
整合编码器和解码器
"""
def __init__(self, encoder, decoder, device):
"""初始化
参数:
encoder: 编码器实例
decoder: 解码器实例
device: 计算设备
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
# 确保解码器与编码器层数相同
assert decoder.rnn.num_layers == encoder.n_layers
self.device = device
def forward(self, src, trg, teacher_forcing_ratio=0.5):
"""前向传播
参数:
src: 源序列
trg: 目标序列(训练时使用)
teacher_forcing_ratio: 教师强制比例
返回:
所有时间步的输出
"""
batch_size = src.shape[0]
trg_len = trg.shape[1]
trg_vocab_size = self.decoder.output_dim
# 初始化输出张量
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
# 编码器处理
encoder_outputs, hidden = self.encoder(src)
# 第一个输入是<SOS>标记
input = trg[:, 0]
# 逐步生成输出序列
for t in range(1, trg_len):
# 解码器处理
output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
outputs[t] = output
# 决定是否使用教师强制
teacher_force = random.random() < teacher_forcing_ratio
top1 = output.argmax(1)
input = trg[:, t] if teacher_force else top1
return outputs
# 训练与评估函数
def train(model, loader, optimizer, criterion, clip):
"""训练函数
参数:
model: 模型
loader: 数据加载器
optimizer: 优化器
criterion: 损失函数
clip: 梯度裁剪阈值
返回:
平均损失
"""
model.train()
epoch_loss = 0
for batch in tqdm(loader, desc="Training"):
src = batch["src"].to(device)
trg = batch["trg"].to(device)
optimizer.zero_grad()
output = model(src, trg)
# 计算损失(忽略第一个token)
output = output[1:].reshape(-1, output.shape[-1])
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(loader)
def evaluate(model, loader, criterion):
"""评估函数
参数:
model: 模型
loader: 数据加载器
criterion: 损失函数
返回:
平均损失
"""
model.eval()
epoch_loss = 0
with torch.no_grad():
for batch in tqdm(loader, desc="Evaluating"):
src = batch["src"].to(device)
trg = batch["trg"].to(device)
# 评估时不使用教师强制
output = model(src, trg, teacher_forcing_ratio=0)
output = output[1:].reshape(-1, output.shape[-1])
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(loader)
# 加载数据(使用opus100数据集的中英翻译部分,只取前10000条作为示例)
dataset = load_dataset("./opus100", "en-zh", split="train[:10000]")
# 划分训练集和验证集
train_val = dataset.train_test_split(test_size=0.2)
# 构建分词器
def get_text_iter(data, lang="zh"):
"""获取文本迭代器
用于分词器训练
"""
for item in data["translation"]:
yield item[lang]
# 训练中文和英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"]))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))
# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)
# 初始化模型
INPUT_DIM = len(zh_tokenizer.get_vocab()) # 中文词汇表大小
OUTPUT_DIM = len(en_tokenizer.get_vocab()) # 英文词汇表大小
ENC_EMB_DIM = 512 # 编码器词嵌入维度
DEC_EMB_DIM = 512 # 解码器词嵌入维度
HID_DIM = 1024 # 隐藏层维度
N_LAYERS = 3 # RNN层数
DROP_RATE = 0.3 # dropout率
# 创建编码器、解码器和seq2seq模型
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE, "dot")
model = Seq2Seq(encoder, decoder, device).to(device)
# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) # 学习率调度器
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略填充标记的损失
CLIP = 5.0 # 梯度裁剪阈值
N_EPOCHS = 20 # 训练轮数
# 训练循环
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
train_loss = train(model, train_loader, optimizer, criterion, CLIP)
valid_loss = evaluate(model, val_loader, criterion)
# 保存最佳模型
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'best_model.pt')
print(f'Epoch: {epoch + 1:02}')
print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')
def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):
"""翻译函数
参数:
model: 训练好的模型
sentence: 待翻译的句子
src_tokenizer: 源语言分词器
trg_tokenizer: 目标语言分词器
max_len: 最大生成长度
返回:
翻译结果
"""
model.eval()
# 中文分词并编码
tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]
src = torch.tensor([src_tokenizer.token_to_id(t) for t in tokens]).unsqueeze(0).to(device)
# 初始化目标序列(以<SOS>开始)
trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]
# 逐步生成目标序列
for i in range(max_len):
trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device)
with torch.no_grad():
output = model(src, trg_tensor)
# 获取预测的下一个token
pred_token = output.argmax(2)[-1].item()
trg_indexes.append(pred_token)
# 如果遇到<EOS>则停止
if pred_token == trg_tokenizer.token_to_id("[EOS]"):
break
# 将ID转换为token
trg_tokens = [trg_tokenizer.id_to_token(i) for i in trg_indexes]
# 去掉<EOS>和<SOS>并返回
return ' '.join(trg_tokens[1:-1])
# 测试翻译
test_sentences = [
"你好世界",
"深度学习很有趣",
"今天天气真好"
]
print("\n测试翻译结果:")
for sent in test_sentences:
translation = translate(model, sent, zh_tokenizer, en_tokenizer)
print(f"中文: {sent} -> 英文: {translation}")
输出为:
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 01
Train Loss: 6.443 | Val. Loss: 6.351
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 02
Train Loss: 6.315 | Val. Loss: 6.400
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.53it/s]
Epoch: 03
Train Loss: 6.307 | Val. Loss: 6.406
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.52it/s]
Epoch: 04
Train Loss: 6.303 | Val. Loss: 6.469
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00, 1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 05
Train Loss: 6.304 | Val. Loss: 6.398
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 06
Train Loss: 6.298 | Val. Loss: 6.421
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.53it/s]
Epoch: 07
Train Loss: 6.298 | Val. Loss: 6.459
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.52it/s]
Epoch: 08
Train Loss: 6.291 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 09
Train Loss: 6.293 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 10
Train Loss: 6.293 | Val. Loss: 6.491
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00, 1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 11
Train Loss: 6.294 | Val. Loss: 6.467
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 12
Train Loss: 6.295 | Val. Loss: 6.439
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:32<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 13
Train Loss: 6.293 | Val. Loss: 6.495
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 14
Train Loss: 6.296 | Val. Loss: 6.471
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 15
Train Loss: 6.300 | Val. Loss: 6.423
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 16
Train Loss: 6.298 | Val. Loss: 6.458
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 17
Train Loss: 6.303 | Val. Loss: 6.510
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 18
Train Loss: 6.305 | Val. Loss: 6.479
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 19
Train Loss: 6.309 | Val. Loss: 6.585
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.57it/s]
Epoch: 20
Train Loss: 6.310 | Val. Loss: 6.515
测试翻译结果:
中文: 你好世界 -> 英文: [PAD]
中文: 深度学习很有趣 -> 英文: [PAD]
中文: 今天天气真好 -> 英文: [PAD]
三、Transformer实现
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import math
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据预处理
def build_tokenizer(text_iter, vocab_size=20000):
tokenizer = Tokenizer(models.Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.UnigramTrainer(
vocab_size=vocab_size,
special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"], # 确保包含UNK
unk_token="[UNK]" # 显式指定UNK token
)
tokenizer.train_from_iterator(text_iter, trainer)
return tokenizer
def get_text_iter(dataset, language="zh"):
for item in dataset["translation"]:
yield item[language]
# 加载数据集
train_data = dataset = load_dataset("opus100", "en-zh")
train_val = train_data["train"].train_test_split(test_size=0.2)
# 构建中英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "zh"))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))
# 词汇表
zh_vocab = zh_tokenizer.get_vocab()
en_vocab = en_tokenizer.get_vocab()
class TranslationDataset(Dataset):
def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=50):
self.data = data
self.src_tokenizer = src_tokenizer
self.trg_tokenizer = trg_tokenizer
self.max_len = max_len
# 获取所有必须的ID(确保不为None)
self.src_unk_id = src_tokenizer.token_to_id("[UNK]")
self.src_sos_id = src_tokenizer.token_to_id("[SOS]")
self.src_eos_id = src_tokenizer.token_to_id("[EOS]")
self.trg_unk_id = trg_tokenizer.token_to_id("[UNK]")
self.trg_sos_id = trg_tokenizer.token_to_id("[SOS]")
self.trg_eos_id = trg_tokenizer.token_to_id("[EOS]")
# 验证关键ID是否存在
self._validate_token_ids()
def _validate_token_ids(self):
for name, token_id in [("SRC_UNK", self.src_unk_id),
("SRC_SOS", self.src_sos_id),
("SRC_EOS", self.src_eos_id),
("TRG_UNK", self.trg_unk_id),
("TRG_SOS", self.trg_sos_id),
("TRG_EOS", self.trg_eos_id)]:
if token_id is None:
raise ValueError(f"{name} token不存在于词汇表中")
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]["translation"]
# 中文编码(源语言)
src_tokens = self._process_sequence(item["zh"], self.src_tokenizer,
self.src_sos_id, self.src_eos_id, self.src_unk_id)
# 英文编码(目标语言)
trg_tokens = self._process_sequence(item["en"], self.trg_tokenizer,
self.trg_sos_id, self.trg_eos_id, self.trg_unk_id)
return {
"src": torch.tensor(src_tokens),
"trg": torch.tensor(trg_tokens)
}
def _process_sequence(self, text, tokenizer, sos_id, eos_id, unk_id):
"""处理单个序列的编码"""
encoded = tokenizer.encode(text)
tokens = encoded.tokens[:self.max_len - 2] # 保留空间给SOS/EOS
# 转换为ID,确保没有None值
token_ids = []
for t in tokens:
token_id = tokenizer.token_to_id(t)
token_ids.append(token_id if token_id is not None else unk_id)
return [sos_id] + token_ids + [eos_id]
def collate_fn(batch):
src = [item["src"] for item in batch]
trg = [item["trg"] for item in batch]
return {
"src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0),
"trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
}
# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# Transformer模型实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=100):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_head, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_model // n_head
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 线性变换并分头
q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
# 计算输出
output = torch.matmul(attn, v)
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
return self.w_o(output), attn
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_head, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask):
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout2(ffn_output))
return x
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_head, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask, tgt_mask):
attn_output, _ = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout1(attn_output))
attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout2(attn_output))
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout3(ffn_output))
return x
class Encoder(nn.Module):
def __init__(self, src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):
super().__init__()
self.token_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=0)
self.pos_embed = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])
self.dropout = nn.Dropout(dropout)
def forward(self, src, src_mask):
x = self.dropout(self.pos_embed(self.token_embed(src)))
for layer in self.layers:
x = layer(x, src_mask)
return x
class Decoder(nn.Module):
def __init__(self, trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):
super().__init__()
self.token_embed = nn.Embedding(trg_vocab_size, d_model, padding_idx=0)
self.pos_embed = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])
self.fc_out = nn.Linear(d_model, trg_vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, trg, encoder_output, src_mask, tgt_mask):
x = self.dropout(self.pos_embed(self.token_embed(trg)))
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.fc_out(x)
class Transformer(nn.Module):
def __init__(self, src_vocab_size, trg_vocab_size, d_model=512, n_layers=6,
n_head=8, d_ff=2048, dropout=0.1, max_len=100):
super().__init__()
self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)
self.decoder = Decoder(trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)
self.src_pad_idx = 0
self.trg_pad_idx = 0
def make_src_mask(self, src):
return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
def make_trg_mask(self, trg):
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
trg_len = trg.shape[1]
trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
return trg_pad_mask & trg_sub_mask
def forward(self, src, trg):
src_mask = self.make_src_mask(src)
tgt_mask = self.make_trg_mask(trg[:, :-1])
encoder_output = self.encoder(src, src_mask)
output = self.decoder(trg[:, :-1], encoder_output, src_mask, tgt_mask)
return output
# 训练与评估
def train(model, iterator, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for batch in train_loader:
src = batch["src"].to(device)
trg = batch["trg"].to(device)
optimizer.zero_grad()
output = model(src, trg)
output_dim = output.shape[-1]
output = output.reshape(-1, output_dim)
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
model.eval()
epoch_loss = 0
with torch.no_grad():
for batch in val_loader:
src = batch["src"].to(device)
trg = batch["trg"].to(device)
output = model(src, trg)
output_dim = output.shape[-1]
output = output.reshape(-1, output_dim)
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(iterator)
# 初始化模型
model = Transformer(
src_vocab_size=len(zh_vocab),
trg_vocab_size=len(en_vocab),
d_model=256, # 减小模型尺寸便于快速训练
n_layers=3,
n_head=4
).to(device)
# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
CLIP = 1.0
N_EPOCHS = 20
# 训练循环
for epoch in range(N_EPOCHS):
train_loss = train(model, train_loader, optimizer, criterion, CLIP)
valid_loss = evaluate(model, val_loader, criterion)
print(f'Epoch: {epoch + 1:02}')
print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')
# 翻译测试函数
def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):
model.eval()
# 编码源语言
src_tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]
src_ids = [src_tokenizer.token_to_id(t) for t in src_tokens]
src = torch.tensor(src_ids).unsqueeze(0).to(device) # [1, src_len]
# 初始化目标序列(始终以SOS开头)
trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]
# 逐步解码
for _ in range(max_len):
trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device) # [1, trg_len]
with torch.no_grad():
output = model(src, trg_tensor) # 形状应为 [1, trg_len, vocab_size]
# 关键修正:安全获取最后一个预测token
if output.size(1) == 0: # 处理空输出情况
pred_token = trg_tokenizer.token_to_id("[UNK]")
else:
pred_token = output.argmax(-1)[0, -1].item() # 获取序列最后一个预测
trg_indexes.append(pred_token)
if pred_token == trg_tokenizer.token_to_id("[EOS]"):
break
# 转换为文本(跳过SOS和EOS)
trg_tokens = []
for i in trg_indexes[1:]: # 跳过初始的SOS
if i == trg_tokenizer.token_to_id("[EOS]"):
break
trg_tokens.append(trg_tokenizer.id_to_token(i))
return ' '.join(trg_tokens)
# 测试翻译
test_sentences = [
"你好世界",
"深度学习很有趣",
"今天天气真好"
]
print("\n测试翻译结果:")
for sent in test_sentences:
translation = translate(model, sent, zh_tokenizer, en_tokenizer)
print(f"中文: {sent} -> 英文: {translation}")
输出为:
Epoch: 01
Train Loss: 4.038 | Val. Loss: 3.263
Epoch: 02
Train Loss: 3.184 | Val. Loss: 2.786
Epoch: 03
Train Loss: 2.833 | Val. Loss: 2.497
Epoch: 04
Train Loss: 2.612 | Val. Loss: 2.323
Epoch: 05
Train Loss: 2.460 | Val. Loss: 2.205
Epoch: 06
Train Loss: 2.352 | Val. Loss: 2.123
Epoch: 07
Train Loss: 2.269 | Val. Loss: 2.055
Epoch: 08
Train Loss: 2.206 | Val. Loss: 2.009
Epoch: 09
Train Loss: 2.154 | Val. Loss: 1.971
Epoch: 10
Train Loss: 2.110 | Val. Loss: 1.936
Epoch: 11
Train Loss: 2.073 | Val. Loss: 1.910
Epoch: 12
Train Loss: 2.041 | Val. Loss: 1.886
Epoch: 13
Train Loss: 2.013 | Val. Loss: 1.866
Epoch: 14
Train Loss: 1.988 | Val. Loss: 1.847
Epoch: 15
Train Loss: 1.966 | Val. Loss: 1.833
Epoch: 16
Train Loss: 1.946 | Val. Loss: 1.820
Epoch: 17
Train Loss: 1.927 | Val. Loss: 1.805
Epoch: 18
Train Loss: 1.910 | Val. Loss: 1.793
Epoch: 19
Train Loss: 1.894 | Val. Loss: 1.786
Epoch: 20
Train Loss: 1.880 | Val. Loss: 1.773
测试翻译结果:
中文: 你好世界 -> 英文: [UNK] Hello THE , FUCK world . .
中文: 深度学习很有趣 -> 英文: [UNK] I of t f . unny
中文: 今天天气真好 -> 英文: [UNK] I weather t today o . n
四、注意力机制变体
1. 稀疏注意力实现
class SparseAttention(nn.Module):
def __init__(self, block_size=32):
super().__init__()
self.block_size = block_size
def forward(self, q, k, v):
batch_size, seq_len, d_model = q.shape
# 分块
q = q.view(batch_size, -1, self.block_size, d_model)
k = k.view(batch_size, -1, self.block_size, d_model)
v = v.view(batch_size, -1, self.block_size, d_model)
# 计算块内注意力
scores = torch.einsum('bind,bjnd->bnij', q, k) / math.sqrt(d_model)
attn = F.softmax(scores, dim=-1)
output = torch.einsum('bnij,bjnd->bind', attn, v)
# 恢复形状
output = output.view(batch_size, seq_len, d_model)
return output, attn
2. 相对位置编码
class RelativePositionEmbedding(nn.Module):
def __init__(self, max_len=512, d_model=512):
super().__init__()
self.emb = nn.Embedding(2 * max_len - 1, d_model)
self.max_len = max_len
def forward(self, q):
"""
q: [batch_size, seq_len, d_model]
"""
seq_len = q.size(1)
range_vec = torch.arange(seq_len)
distance_mat = range_vec[None, :] - range_vec[:, None] # [seq_len, seq_len]
distance_mat_clipped = torch.clamp(distance_mat + self.max_len - 1, 0, 2 * self.max_len - 2)
position_emb = self.emb(distance_mat_clipped) # [seq_len, seq_len, d_model]
return position_emb
五、性能对比与总结
1.注意力模式可视化
def plot_attention(attention, source, target):
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
cax = ax.matshow(attention, cmap='bone')
ax.set_xticklabels([''] + source, rotation=90)
ax.set_yticklabels([''] + target)
plt.show()
2. 关键演进规律
信息瓶颈突破:从固定长度上下文到动态注意力分配
计算效率提升:从RNN的O(n)序列计算到Transformer的并行化
建模能力增强:从局部依赖到全局关系建模
在下一篇文章中,我们将深入探讨归一化技术对比(BN/LN/IN/GN),分析不同归一化方法的特点和适用场景。