PyTorch 深度学习实战(38):注意力机制全面解析(从Seq2Seq到Transformer)

发布于:2025-04-20 ⋅ 阅读:(18) ⋅ 点赞:(0)

在上一篇文章中,我们探讨了分布式训练实战。本文将深入解析注意力机制的完整发展历程,从最初的Seq2Seq模型到革命性的Transformer架构。我们将使用PyTorch实现2个关键阶段的注意力机制变体,并在机器翻译任务上进行对比实验。

一、注意力机制演进路线

1. 关键模型对比

模型 发表年份 核心创新 计算复杂度 典型应用
Seq2Seq 2014 编码器-解码器架构 O(n²) 机器翻译
Bahdanau Attention 2015 软注意力机制(动态上下文向量) O(n²) 文本生成、语音识别
Luong Attention 2015 全局/局部注意力(改进对齐方式) O(n²) 语音识别、长文本翻译
Transformer 2017 自注意力机制(并行化处理) O(n²) 所有序列任务(NLP/CV)
Sparse Transformer 2019 稀疏注意力(分块处理长序列) O(n√n) 长文本生成、基因序列分析
MQA 2023 多查询注意力(共享KV减少内存) O(n log n) 大模型推理加速
GQA 2024 分组查询注意力(平衡精度与效率) O(n log n) 工业级大模型部署
Flash Attention 2024 分块计算优化KV缓存 O(n√n) 超长序列处理(>10k tokens)
DeepSeek MLA 2025 多头潜在注意力(潜在空间投影) O(n log n) 多模态融合、复杂推理任务
TPA 2025 张量积分解注意力(动态秩优化) O(n) 边缘计算、低资源环境
MoBA 2025 混合块注意力(Top-K门控选择) O(n log n) 百万级长文本处理
ECA 2025 高效通道注意力(参数无关门控) O(1) 图像分类、目标检测

2. 注意力类型分类

class AttentionTypes:
    def __init__(self):
        self.soft_attention = ["加性注意力", "点积注意力"]
        self.hard_attention = ["随机硬注意力", "最大似然注意力"] 
        self.self_attention = ["标准自注意力", "稀疏自注意力"]
        self.cross_attention = ["编码器-解码器注意力"]

二、基础注意力机制实现

1. 环境配置

pip install torch  matplotlib

2. Luong注意力实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import random
from tqdm import tqdm

# 设备配置 - 检查是否有可用的GPU,没有则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 数据预处理部分
def build_tokenizer(text_iter, vocab_size=20000):
    """构建分词器
    参数:
        text_iter: 文本迭代器
        vocab_size: 词汇表大小
    返回:
        训练好的分词器
    """
    # 使用Unigram模型初始化分词器
    tokenizer = Tokenizer(models.Unigram())
    # 使用空格作为预分词器
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    # 配置训练器
    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"],  # 特殊标记
        unk_token="[UNK]"  # 显式设置UNK标记
    )
    # 从文本迭代器训练分词器
    tokenizer.train_from_iterator(text_iter, trainer)

    # 确保UNK标记在分词器中正确设置
    if tokenizer.token_to_id("[UNK]") is None:
        raise ValueError("UNK token not properly initialized in tokenizer")

    return tokenizer


class TranslationDataset(Dataset):
    """翻译数据集类
    用于加载和处理翻译数据
    """

    def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=100):
        """初始化
        参数:
            data: 原始数据
            src_tokenizer: 源语言分词器
            trg_tokenizer: 目标语言分词器
            max_len: 最大序列长度
        """
        self.data = data
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer
        self.max_len = max_len

        # 获取UNK标记的ID
        self.src_unk_id = self.src_tokenizer.token_to_id("[UNK]")
        self.trg_unk_id = self.trg_tokenizer.token_to_id("[UNK]")

        if self.src_unk_id is None or self.trg_unk_id is None:
            raise ValueError("Tokenizers must have [UNK] token")

    def __len__(self):
        """返回数据集大小"""
        return len(self.data)

    def __getitem__(self, idx):
        """获取单个样本
        参数:
            idx: 索引
        返回:
            包含源语言和目标语言token ID的字典
        """
        item = self.data[idx]["translation"]

        # 源语言(中文)处理
        src_encoded = self.src_tokenizer.encode(item["zh"])
        # 添加开始和结束标记,并截断到最大长度
        src_tokens = ["[SOS]"] + src_encoded.tokens[:self.max_len - 2] + ["[EOS]"]
        # 将token转换为ID,未知token使用UNK ID
        src_ids = [self.src_tokenizer.token_to_id(t) or self.src_unk_id for t in src_tokens]

        # 目标语言(英文)处理
        trg_encoded = self.trg_tokenizer.encode(item["en"])
        trg_tokens = ["[SOS]"] + trg_encoded.tokens[:self.max_len - 2] + ["[EOS]"]
        trg_ids = [self.trg_tokenizer.token_to_id(t) or self.trg_unk_id for t in trg_tokens]

        return {
            "src": torch.tensor(src_ids),
            "trg": torch.tensor(trg_ids)
        }


def collate_fn(batch):
    """批处理函数
    用于DataLoader中对批次数据进行填充
    """
    src = [item["src"] for item in batch]
    trg = [item["trg"] for item in batch]
    return {
        "src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0),  # 用0填充
        "trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
    }


# 模型实现部分
class Encoder(nn.Module):
    """编码器
    将输入序列编码为隐藏状态
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1):
        """初始化
        参数:
            input_dim: 输入维度(词汇表大小)
            emb_dim: 词嵌入维度
            hid_dim: 隐藏层维度
            n_layers: RNN层数
            dropout: dropout率
        """
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=0)  # 词嵌入层
        # 双向GRU
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers,
                          dropout=dropout if n_layers > 1 else 0,
                          bidirectional=True)
        self.fc = nn.Linear(hid_dim * 2, hid_dim)  # 用于合并双向输出的全连接层
        self.dropout = nn.Dropout(dropout)
        self.n_layers = n_layers

    def forward(self, src):
        """前向传播
        参数:
            src: 输入序列
        返回:
            outputs: 编码器所有时间步的输出
            hidden: 最后一个时间步的隐藏状态
        """
        # 词嵌入 + dropout
        embedded = self.dropout(self.embedding(src))  # [batch_size, src_len, emb_dim]
        # GRU处理 (需要将batch维度放在第二位)
        outputs, hidden = self.rnn(embedded.transpose(0, 1))  # outputs: [src_len, batch_size, hid_dim*2]

        # 处理双向隐藏状态
        # hidden的形状是[num_layers * num_directions, batch_size, hid_dim]
        hidden = hidden.view(self.n_layers, 2, -1, self.rnn.hidden_size)  # [n_layers, 2, batch_size, hid_dim]
        hidden = hidden[-1]  # 取最后一层 [2, batch_size, hid_dim]
        hidden = torch.cat([hidden[0], hidden[1]], dim=1)  # 合并双向输出 [batch_size, hid_dim*2]
        hidden = torch.tanh(self.fc(hidden))  # [batch_size, hid_dim]

        # 扩展以匹配解码器的层数
        hidden = hidden.unsqueeze(0).repeat(self.n_layers, 1, 1)  # [n_layers, batch_size, hid_dim]

        return outputs, hidden


class LuongAttention(nn.Module):
    def __init__(self, hid_dim, method="general"):
        super().__init__()
        self.method = method
        if method == "general":
            self.W = nn.Linear(hid_dim, hid_dim, bias=False)
        elif method == "concat":
            self.W = nn.Linear(hid_dim * 2, hid_dim, bias=False)
            self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        """ decoder_hidden: [1, batch_size, hid_dim]
            encoder_outputs: [src_len, batch_size, hid_dim * 2] (bidirectional)"""

        if self.method == "dot":
            # 处理双向输出 - 取前向和后向的平均
            hid_dim = decoder_hidden.size(-1)
            encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)
            encoder_outputs = encoder_outputs.mean(dim=2)  # [src_len, batch_size, hid_dim]

            # 计算点积分数
            scores = torch.matmul(
                encoder_outputs.transpose(0, 1),  # [batch_size, src_len, hid_dim]
                decoder_hidden.transpose(0, 1).transpose(1, 2)  # [batch_size, hid_dim, 1]
            ).squeeze(2)  # [batch_size, src_len]

            # 添加缩放因子
            scores = scores / (decoder_hidden.size(-1) ** 0.5)

        elif self.method == "general":
            # 对于通用注意力,我们需要对解码器隐藏状态进行投影
            decoder_hidden_proj = self.W(decoder_hidden)  # [1, batch_size, hid_dim]

            # 处理双向输出
            hid_dim = decoder_hidden.size(-1)
            encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)
            encoder_outputs = encoder_outputs.mean(dim=2)  # [src_len, batch_size, hid_dim]

            scores = torch.matmul(
                encoder_outputs.transpose(0, 1),  # [batch_size, src_len, hid_dim]
                decoder_hidden_proj.transpose(0, 1).transpose(1, 2)  # [batch_size, hid_dim, 1]
            ).squeeze(2)  # [batch_size, src_len]

        elif self.method == "concat":
            # 对于concat,我们可以使用完整的双向输出
            decoder_hidden = decoder_hidden.repeat(encoder_outputs.size(0), 1, 1)  # [src_len, batch_size, hid_dim]
            energy = torch.cat((decoder_hidden, encoder_outputs), dim=2)  # [src_len, batch_size, hid_dim*3]
            scores = self.v(torch.tanh(self.W(energy))).squeeze(2).t()  # [batch_size, src_len]

        attn_weights = F.softmax(scores, dim=1)

        # 对于上下文向量,使用原始双向输出
        context = torch.bmm(
            attn_weights.unsqueeze(1),  # [batch_size, 1, src_len]
            encoder_outputs.transpose(0, 1)  # [batch_size, src_len, hid_dim*2]
        ).squeeze(1)  # [batch_size, hid_dim*2]

        return context, attn_weights


class Decoder(nn.Module):
    """解码器
    使用注意力机制生成目标序列
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1, attn_method="general"):
        """初始化
        参数:
            output_dim: 输出维度(目标词汇表大小)
            emb_dim: 词嵌入维度
            hid_dim: 隐藏层维度
            n_layers: RNN层数
            dropout: dropout率
            attn_method: 注意力计算方法
        """
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=0)  # 词嵌入层
        # 单向GRU
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers,
                          dropout=dropout if n_layers > 1 else 0)
        self.attention = LuongAttention(hid_dim, attn_method)  # 注意力层
        # 全连接层(根据注意力方法调整输入维度)
        self.fc = nn.Linear(hid_dim * 3 if attn_method == "concat" else hid_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """前向传播
        参数:
            input: 当前输入token
            hidden: 当前隐藏状态
            encoder_outputs: 编码器输出
        返回:
            prediction: 预测的下一个token
            hidden: 新的隐藏状态
            attn_weights: 注意力权重
        """
        input = input.unsqueeze(0)  # [1, batch_size]
        embedded = self.dropout(self.embedding(input))  # [1, batch_size, emb_dim]

        # GRU处理
        output, hidden = self.rnn(embedded, hidden)  # output: [1, batch_size, hid_dim]

        # 计算注意力
        context, attn_weights = self.attention(output, encoder_outputs)

        # 预测下一个token(拼接RNN输出和上下文向量)
        prediction = self.fc(torch.cat((output.squeeze(0), context), dim=1))
        return prediction, hidden, attn_weights


class Seq2Seq(nn.Module):
    """序列到序列模型
    整合编码器和解码器
    """

    def __init__(self, encoder, decoder, device):
        """初始化
        参数:
            encoder: 编码器实例
            decoder: 解码器实例
            device: 计算设备
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # 确保解码器与编码器层数相同
        assert decoder.rnn.num_layers == encoder.n_layers
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """前向传播
        参数:
            src: 源序列
            trg: 目标序列(训练时使用)
            teacher_forcing_ratio: 教师强制比例
        返回:
            所有时间步的输出
        """
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # 初始化输出张量
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        # 编码器处理
        encoder_outputs, hidden = self.encoder(src)

        # 第一个输入是<SOS>标记
        input = trg[:, 0]

        # 逐步生成输出序列
        for t in range(1, trg_len):
            # 解码器处理
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output
            # 决定是否使用教师强制
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[:, t] if teacher_force else top1

        return outputs


# 训练与评估函数
def train(model, loader, optimizer, criterion, clip):
    """训练函数
    参数:
        model: 模型
        loader: 数据加载器
        optimizer: 优化器
        criterion: 损失函数
        clip: 梯度裁剪阈值
    返回:
        平均损失
    """
    model.train()
    epoch_loss = 0

    for batch in tqdm(loader, desc="Training"):
        src = batch["src"].to(device)
        trg = batch["trg"].to(device)

        optimizer.zero_grad()
        output = model(src, trg)

        # 计算损失(忽略第一个token)
        output = output[1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(loader)


def evaluate(model, loader, criterion):
    """评估函数
    参数:
        model: 模型
        loader: 数据加载器
        criterion: 损失函数
    返回:
        平均损失
    """
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            src = batch["src"].to(device)
            trg = batch["trg"].to(device)

            # 评估时不使用教师强制
            output = model(src, trg, teacher_forcing_ratio=0)
            output = output[1:].reshape(-1, output.shape[-1])
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(loader)

# 加载数据(使用opus100数据集的中英翻译部分,只取前10000条作为示例)
dataset = load_dataset("./opus100", "en-zh", split="train[:10000]")
# 划分训练集和验证集
train_val = dataset.train_test_split(test_size=0.2)


# 构建分词器
def get_text_iter(data, lang="zh"):
    """获取文本迭代器
    用于分词器训练
    """
    for item in data["translation"]:
        yield item[lang]


# 训练中文和英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"]))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))

# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)

# 初始化模型
INPUT_DIM = len(zh_tokenizer.get_vocab())  # 中文词汇表大小
OUTPUT_DIM = len(en_tokenizer.get_vocab())  # 英文词汇表大小
ENC_EMB_DIM = 512  # 编码器词嵌入维度
DEC_EMB_DIM = 512  # 解码器词嵌入维度
HID_DIM = 1024  # 隐藏层维度
N_LAYERS = 3  # RNN层数
DROP_RATE = 0.3  # dropout率

# 创建编码器、解码器和seq2seq模型
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE, "dot")
model = Seq2Seq(encoder, decoder, device).to(device)

# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 学习率调度器
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充标记的损失
CLIP = 5.0  # 梯度裁剪阈值
N_EPOCHS = 20  # 训练轮数

# 训练循环
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_loader, criterion)

    # 保存最佳模型
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pt')

    print(f'Epoch: {epoch + 1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')


def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):
    """翻译函数
    参数:
        model: 训练好的模型
        sentence: 待翻译的句子
        src_tokenizer: 源语言分词器
        trg_tokenizer: 目标语言分词器
        max_len: 最大生成长度
    返回:
        翻译结果
    """
    model.eval()

    # 中文分词并编码
    tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]
    src = torch.tensor([src_tokenizer.token_to_id(t) for t in tokens]).unsqueeze(0).to(device)

    # 初始化目标序列(以<SOS>开始)
    trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]

    # 逐步生成目标序列
    for i in range(max_len):
        trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(src, trg_tensor)

        # 获取预测的下一个token
        pred_token = output.argmax(2)[-1].item()
        trg_indexes.append(pred_token)

        # 如果遇到<EOS>则停止
        if pred_token == trg_tokenizer.token_to_id("[EOS]"):
            break

    # 将ID转换为token
    trg_tokens = [trg_tokenizer.id_to_token(i) for i in trg_indexes]
    # 去掉<EOS>和<SOS>并返回
    return ' '.join(trg_tokens[1:-1])


# 测试翻译
test_sentences = [
    "你好世界",
    "深度学习很有趣",
    "今天天气真好"
]

print("\n测试翻译结果:")
for sent in test_sentences:
    translation = translate(model, sent, zh_tokenizer, en_tokenizer)
    print(f"中文: {sent} -> 英文: {translation}")

输出为:

Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 01
        Train Loss: 6.443 | Val. Loss: 6.351
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.56it/s]
Epoch: 02
        Train Loss: 6.315 | Val. Loss: 6.400
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00,  1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.53it/s]
Epoch: 03
        Train Loss: 6.307 | Val. Loss: 6.406
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00,  1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.52it/s]
Epoch: 04
        Train Loss: 6.303 | Val. Loss: 6.469
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00,  1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 05
        Train Loss: 6.304 | Val. Loss: 6.398
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00,  1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 06
        Train Loss: 6.298 | Val. Loss: 6.421
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.53it/s]
Epoch: 07
        Train Loss: 6.298 | Val. Loss: 6.459
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.52it/s]
Epoch: 08
        Train Loss: 6.291 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 09
        Train Loss: 6.293 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 10
        Train Loss: 6.293 | Val. Loss: 6.491
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00,  1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.55it/s]
Epoch: 11
        Train Loss: 6.294 | Val. Loss: 6.467
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.54it/s]
Epoch: 12
        Train Loss: 6.295 | Val. Loss: 6.439
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:32<00:00,  1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.56it/s]
Epoch: 13
        Train Loss: 6.293 | Val. Loss: 6.495
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00,  1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.55it/s]
Epoch: 14
        Train Loss: 6.296 | Val. Loss: 6.471
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.56it/s]
Epoch: 15
        Train Loss: 6.300 | Val. Loss: 6.423
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00,  1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.55it/s]
Epoch: 16
        Train Loss: 6.298 | Val. Loss: 6.458
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.56it/s]
Epoch: 17
        Train Loss: 6.303 | Val. Loss: 6.510
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.55it/s]
Epoch: 18
        Train Loss: 6.305 | Val. Loss: 6.479
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.56it/s]
Epoch: 19
        Train Loss: 6.309 | Val. Loss: 6.585
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00,  1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00,  4.57it/s]
Epoch: 20
        Train Loss: 6.310 | Val. Loss: 6.515

测试翻译结果:
中文: 你好世界 -> 英文: [PAD]
中文: 深度学习很有趣 -> 英文: [PAD]
中文: 今天天气真好 -> 英文: [PAD]

三、Transformer实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import math

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 数据预处理
def build_tokenizer(text_iter, vocab_size=20000):
    tokenizer = Tokenizer(models.Unigram())
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"],  # 确保包含UNK
        unk_token="[UNK]"  # 显式指定UNK token
    )
    tokenizer.train_from_iterator(text_iter, trainer)
    return tokenizer


def get_text_iter(dataset, language="zh"):
    for item in dataset["translation"]:
        yield item[language]


# 加载数据集
train_data = dataset = load_dataset("opus100", "en-zh")
train_val = train_data["train"].train_test_split(test_size=0.2)

# 构建中英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "zh"))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))

# 词汇表
zh_vocab = zh_tokenizer.get_vocab()
en_vocab = en_tokenizer.get_vocab()


class TranslationDataset(Dataset):
    def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=50):
        self.data = data
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer
        self.max_len = max_len

        # 获取所有必须的ID(确保不为None)
        self.src_unk_id = src_tokenizer.token_to_id("[UNK]")
        self.src_sos_id = src_tokenizer.token_to_id("[SOS]")
        self.src_eos_id = src_tokenizer.token_to_id("[EOS]")
        self.trg_unk_id = trg_tokenizer.token_to_id("[UNK]")
        self.trg_sos_id = trg_tokenizer.token_to_id("[SOS]")
        self.trg_eos_id = trg_tokenizer.token_to_id("[EOS]")

        # 验证关键ID是否存在
        self._validate_token_ids()

    def _validate_token_ids(self):
        for name, token_id in [("SRC_UNK", self.src_unk_id),
                               ("SRC_SOS", self.src_sos_id),
                               ("SRC_EOS", self.src_eos_id),
                               ("TRG_UNK", self.trg_unk_id),
                               ("TRG_SOS", self.trg_sos_id),
                               ("TRG_EOS", self.trg_eos_id)]:
            if token_id is None:
                raise ValueError(f"{name} token不存在于词汇表中")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]["translation"]

        # 中文编码(源语言)
        src_tokens = self._process_sequence(item["zh"], self.src_tokenizer,
                                            self.src_sos_id, self.src_eos_id, self.src_unk_id)

        # 英文编码(目标语言)
        trg_tokens = self._process_sequence(item["en"], self.trg_tokenizer,
                                            self.trg_sos_id, self.trg_eos_id, self.trg_unk_id)

        return {
            "src": torch.tensor(src_tokens),
            "trg": torch.tensor(trg_tokens)
        }

    def _process_sequence(self, text, tokenizer, sos_id, eos_id, unk_id):
        """处理单个序列的编码"""
        encoded = tokenizer.encode(text)
        tokens = encoded.tokens[:self.max_len - 2]  # 保留空间给SOS/EOS

        # 转换为ID,确保没有None值
        token_ids = []
        for t in tokens:
            token_id = tokenizer.token_to_id(t)
            token_ids.append(token_id if token_id is not None else unk_id)

        return [sos_id] + token_ids + [eos_id]


def collate_fn(batch):
    src = [item["src"] for item in batch]
    trg = [item["trg"] for item in batch]
    return {
        "src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0),
        "trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)
    }


# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)

BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)


# Transformer模型实现
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # 线性变换并分头
        q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # 计算输出
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
        return self.w_o(output), attn


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout2(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout3(ffn_output))
        return x


class Encoder(nn.Module):
    def __init__(self, src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):
        super().__init__()
        self.token_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=0)
        self.pos_embed = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        x = self.dropout(self.pos_embed(self.token_embed(src)))
        for layer in self.layers:
            x = layer(x, src_mask)
        return x


class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):
        super().__init__()
        self.token_embed = nn.Embedding(trg_vocab_size, d_model, padding_idx=0)
        self.pos_embed = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])
        self.fc_out = nn.Linear(d_model, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, encoder_output, src_mask, tgt_mask):
        x = self.dropout(self.pos_embed(self.token_embed(trg)))
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.fc_out(x)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, d_model=512, n_layers=6,
                 n_head=8, d_ff=2048, dropout=0.1, max_len=100):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)
        self.decoder = Decoder(trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)
        self.src_pad_idx = 0
        self.trg_pad_idx = 0

    def make_src_mask(self, src):
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
        return trg_pad_mask & trg_sub_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_trg_mask(trg[:, :-1])
        encoder_output = self.encoder(src, src_mask)
        output = self.decoder(trg[:, :-1], encoder_output, src_mask, tgt_mask)
        return output


# 训练与评估
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0

    for batch in train_loader:
        src = batch["src"].to(device)
        trg = batch["trg"].to(device)

        optimizer.zero_grad()
        output = model(src, trg)

        output_dim = output.shape[-1]
        output = output.reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            src = batch["src"].to(device)
            trg = batch["trg"].to(device)

            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output.reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

# 初始化模型
model = Transformer(
    src_vocab_size=len(zh_vocab),
    trg_vocab_size=len(en_vocab),
    d_model=256,  # 减小模型尺寸便于快速训练
    n_layers=3,
    n_head=4
).to(device)

# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
CLIP = 1.0
N_EPOCHS = 20

# 训练循环
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_loader, criterion)

    print(f'Epoch: {epoch + 1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')


# 翻译测试函数
def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):
    model.eval()

    # 编码源语言
    src_tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]
    src_ids = [src_tokenizer.token_to_id(t) for t in src_tokens]
    src = torch.tensor(src_ids).unsqueeze(0).to(device)  # [1, src_len]

    # 初始化目标序列(始终以SOS开头)
    trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]

    # 逐步解码
    for _ in range(max_len):
        trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device)  # [1, trg_len]

        with torch.no_grad():
            output = model(src, trg_tensor)  # 形状应为 [1, trg_len, vocab_size]

            # 关键修正:安全获取最后一个预测token
            if output.size(1) == 0:  # 处理空输出情况
                pred_token = trg_tokenizer.token_to_id("[UNK]")
            else:
                pred_token = output.argmax(-1)[0, -1].item()  # 获取序列最后一个预测

        trg_indexes.append(pred_token)
        if pred_token == trg_tokenizer.token_to_id("[EOS]"):
            break

    # 转换为文本(跳过SOS和EOS)
    trg_tokens = []
    for i in trg_indexes[1:]:  # 跳过初始的SOS
        if i == trg_tokenizer.token_to_id("[EOS]"):
            break
        trg_tokens.append(trg_tokenizer.id_to_token(i))

    return ' '.join(trg_tokens)


# 测试翻译
test_sentences = [
    "你好世界",
    "深度学习很有趣",
    "今天天气真好"
]

print("\n测试翻译结果:")
for sent in test_sentences:
    translation = translate(model, sent, zh_tokenizer, en_tokenizer)
    print(f"中文: {sent} -> 英文: {translation}")

输出为:

Epoch: 01
        Train Loss: 4.038 | Val. Loss: 3.263
Epoch: 02
        Train Loss: 3.184 | Val. Loss: 2.786
Epoch: 03
        Train Loss: 2.833 | Val. Loss: 2.497
Epoch: 04
        Train Loss: 2.612 | Val. Loss: 2.323
Epoch: 05
        Train Loss: 2.460 | Val. Loss: 2.205
Epoch: 06
        Train Loss: 2.352 | Val. Loss: 2.123
Epoch: 07
        Train Loss: 2.269 | Val. Loss: 2.055
Epoch: 08
        Train Loss: 2.206 | Val. Loss: 2.009
Epoch: 09
        Train Loss: 2.154 | Val. Loss: 1.971
Epoch: 10
        Train Loss: 2.110 | Val. Loss: 1.936
Epoch: 11
        Train Loss: 2.073 | Val. Loss: 1.910
Epoch: 12
        Train Loss: 2.041 | Val. Loss: 1.886
Epoch: 13
        Train Loss: 2.013 | Val. Loss: 1.866
Epoch: 14
        Train Loss: 1.988 | Val. Loss: 1.847
Epoch: 15
        Train Loss: 1.966 | Val. Loss: 1.833
Epoch: 16
        Train Loss: 1.946 | Val. Loss: 1.820
Epoch: 17
        Train Loss: 1.927 | Val. Loss: 1.805
Epoch: 18
        Train Loss: 1.910 | Val. Loss: 1.793
Epoch: 19
        Train Loss: 1.894 | Val. Loss: 1.786
Epoch: 20
        Train Loss: 1.880 | Val. Loss: 1.773

测试翻译结果:
中文: 你好世界 -> 英文: [UNK] Hello THE , FUCK world . .
中文: 深度学习很有趣 -> 英文: [UNK] I of t f . unny
中文: 今天天气真好 -> 英文: [UNK] I weather t today o . n

四、注意力机制变体

1. 稀疏注意力实现

class SparseAttention(nn.Module):
    def __init__(self, block_size=32):
        super().__init__()
        self.block_size = block_size
    
    def forward(self, q, k, v):
        batch_size, seq_len, d_model = q.shape

        # 分块
        q = q.view(batch_size, -1, self.block_size, d_model)
        k = k.view(batch_size, -1, self.block_size, d_model)
        v = v.view(batch_size, -1, self.block_size, d_model)

        # 计算块内注意力
        scores = torch.einsum('bind,bjnd->bnij', q, k) / math.sqrt(d_model)
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum('bnij,bjnd->bind', attn, v)

        # 恢复形状
        output = output.view(batch_size, seq_len, d_model)
        return output, attn

2. 相对位置编码

class RelativePositionEmbedding(nn.Module):
    def __init__(self, max_len=512, d_model=512):
        super().__init__()
        self.emb = nn.Embedding(2 * max_len - 1, d_model)
        self.max_len = max_len
        
    def forward(self, q):
        """
        q: [batch_size, seq_len, d_model]
        """
        seq_len = q.size(1)
        range_vec = torch.arange(seq_len)
        distance_mat = range_vec[None, :] - range_vec[:, None]  # [seq_len, seq_len]
        distance_mat_clipped = torch.clamp(distance_mat + self.max_len - 1, 0, 2 * self.max_len - 2)
        position_emb = self.emb(distance_mat_clipped)  # [seq_len, seq_len, d_model]
        return position_emb

五、性能对比与总结

1.注意力模式可视化

def plot_attention(attention, source, target):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    cax = ax.matshow(attention, cmap='bone')
    ax.set_xticklabels([''] + source, rotation=90)
    ax.set_yticklabels([''] + target)
    plt.show()

2. 关键演进规律

  1. 信息瓶颈突破:从固定长度上下文到动态注意力分配

  2. 计算效率提升:从RNN的O(n)序列计算到Transformer的并行化

  3. 建模能力增强:从局部依赖到全局关系建模

在下一篇文章中,我们将深入探讨归一化技术对比(BN/LN/IN/GN),分析不同归一化方法的特点和适用场景。


网站公告

今日签到

点亮在社区的每一天
去签到