【NLP】Transformer网络结构(2)

发布于:2025-04-08 ⋅ 阅读:(19) ⋅ 点赞:(0)

一、Transformer 整体架构

Transformer 由 Encoder 和 Decoder 堆叠组成,每个 Encoder/Decoder 层包含以下核心模块:

  • Encoder 层:Multi-Head Self-Attention → Add & LayerNorm → Feed-Forward → Add & LayerNorm
  • Decoder 层:Masked Multi-Head Self-Attention → Add & LayerNorm → Cross-Attention (Encoder-Decoder) → Add & LayerNorm → Feed-Forward → Add & LayerNorm

二、核心组件代码实现与解析

1. 位置编码 (Positional Embeddings)

原理:将序列中每个 token 的位置信息编码为向量,与词向量相加后输入模型。常用正弦/余弦函数或可学习参数实现。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置
        self.register_buffer('pe', pe)  # 非可学习参数

    def forward(self, x):
        # x: (B, seq_len, d_model)
        x = x + self.pe[:x.size(1), :]  # 仅取前seq_len个位置的编码
        return x

代码解析

  • 固定模式编码:使用正弦函数生成位置编码,区分奇偶位置维度。
  • 可扩展性max_len 预设最大序列长度,适用于不同输入长度。
  • 加法融合:直接将位置编码与词嵌入相加,保留位置信息。

2. 层归一化 (Layer Normalization)

原理:对每个样本的特征维度(而非批量维度)进行归一化,稳定训练。

# PyTorch 内置实现
layer_norm = nn.LayerNorm(d_model)

关键特性

  • 归一化维度:沿特征维度(d_model)计算均值和方差。
  • 残差连接:常与残差结构配合使用(如 x + sublayer(x))。

3. 前馈网络 (Feed-Forward Layer)

原理:通过两个线性变换与激活函数,增强模型非线性表达能力。

class FeedForward(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

代码解析

  • 维度扩展:先升维(d_model → dim_feedforward)再降维,增强特征交互。
  • 激活函数:使用 ReLU 引入非线性。

4. Encoder 层
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead)
        self.ffn = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        # 1. Self-Attention
        attn_output, _ = self.self_attn(src, src, src, mask=src_mask)
        src = src + self.dropout(attn_output)  # 残差连接
        src = self.norm1(src)
        
        # 2. Feed-Forward
        ffn_output = self.ffn(src)
        src = src + self.dropout(ffn_output)  # 残差连接
        src = self.norm2(src)
        return src

代码解析

  • 残差连接x + dropout(sublayer(x)) 缓解梯度消失。
  • 层归一化顺序:先残差连接后归一化(Post-LN 结构)。

5. Decoder 层
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead)
        self.cross_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead)
        self.ffn = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # 1. Masked Self-Attention
        attn_output, _ = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output)
        tgt = self.norm1(tgt)
        
        # 2. Encoder-Decoder Cross Attention
        attn_output, _ = self.cross_attn(tgt, memory, memory, mask=memory_mask)
        tgt = tgt + self.dropout(attn_output)
        tgt = self.norm2(tgt)
        
        # 3. Feed-Forward
        ffn_output = self.ffn(tgt)
        tgt = tgt + self.dropout(ffn_output)
        tgt = self.norm3(tgt)
        return tgt

代码解析

  • Masked Self-Attention:使用 tgt_mask 遮挡未来词(如对角线矩阵)。
  • Cross-Attentiontgt 作为 Query,memory(Encoder 输出)作为 Key/Value。

6. 完整 Transformer 模型
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, 
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # 1. 嵌入层与位置编码
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        
        # 2. Encoder 和 Decoder 堆叠
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # 3. 输出层
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # 编码器
        src_emb = self.pos_encoder(self.src_embed(src))
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask)
        
        # 解码器
        tgt_emb = self.pos_encoder(self.tgt_embed(tgt))
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, src_emb, tgt_mask, None)
        
        # 输出预测
        output = self.fc_out(tgt_emb)
        return output

关键参数

  • src_vocab_size:源语言词表大小。
  • tgt_vocab_size:目标语言词表大小。
  • d_model:模型维度(通常与词嵌入维度一致)。
  • nhead:多头注意力头数。
  • num_encoder/decoder_layers:Encoder/Decoder 层数。

三、使用示例

# 参数设置
src_vocab_size = 10000  # 源语言词表大小
tgt_vocab_size = 8000   # 目标语言词表大小
d_model = 512
nhead = 8
num_layers = 6
batch_size = 32
seq_len = 50

# 初始化模型
model = Transformer(src_vocab_size, tgt_vocab_size, d_model, nhead, num_layers, num_layers)

# 随机生成输入
src = torch.randint(0, src_vocab_size, (batch_size, seq_len))  # (B, seq_len)
tgt = torch.randint(0, tgt_vocab_size, (batch_size, seq_len))  # (B, seq_len)

# 前向计算
output = model(src, tgt)
print("Output shape:", output.shape)  # (B, seq_len, tgt_vocab_size)

网站公告

今日签到

点亮在社区的每一天
去签到