Implementing a Transformer Model by Hand

This article uses the PyTorch library to implement by hand the submodules of the classic Transformer model: multi-head self-attention, residual connection with layer normalization (Add & Norm), the feed-forward layer, the encoder, and the decoder. These submodules are then assembled into a complete Transformer.
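Before walking through the code, here is a minimal usage sketch of the finished model (it mirrors the test in the __main__ block at the end of the file; variable names such as src and tgt are illustrative). The input tensors stand in for already-embedded source and target sequences of shape (batch_size, seq_len, d_model); token embedding and positional encoding are outside the scope of this module.

import torch

src = torch.rand((2, 5, 512))  # embedded source sequence: batch of 2, length 5, d_model = 512
tgt = torch.rand((2, 3, 512))  # embedded target sequence: batch of 2, length 3, d_model = 512
model = Transformer()          # the class defined at the end of the file below
out = model(src, tgt)          # out.shape == torch.Size([2, 3, 512])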

"""
@Title: Dissecting the Transformer
@Time: 2025/5/10
@Author: Michael Jie
"""

import math

import torch
import torch.nn.functional as F
from torch import nn, Tensor


# Scaled Dot-Product Attention
class Attention(nn.Module):
    def __init__(self, causal: bool = True) -> None:
        """
        注意力公式:Attention(Q, K, V) = softmax(Q · K / sqrt(d_k)) · V

        Args:
            causal: 是否自动生成因果掩码,默认为 True
        """
        super(Attention, self).__init__()
        self.causal = causal

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> tuple[Tensor, Tensor]:
        """
        填充掩码:处理变长序列,避免填充影响注意力计算
        因果掩码:防止解码器在训练时看到未来的信息

        Args:
            q: 查询 shape=(..., seq_len_q, d_k)
            k: 键 shape=(..., seq_len_k, d_k)
            v: 值 shape=(..., seq_len_k, d_v)
            padding_mask: 填充掩码 shape=(..., seq_len_k)
            attn_mask: 因果掩码 shape=(..., seq_len_q, seq_len_k)

        Returns:
            output: 输出 shape=(..., seq_len_q, d_v)
            weights: 注意力权重 shape=(..., seq_len_q, seq_len_k)
        """
        # Attention scores
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
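        # Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension,
        # which would otherwise push the softmax into regions with very small gradients.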

        # Apply the padding mask
        if padding_mask is not None:
            # Broadcast to (..., 1, seq_len_k)
            scores = scores.masked_fill(padding_mask.unsqueeze(-2), float("-inf"))

        # Automatically generate a causal mask; a user-supplied causal mask takes precedence
        seq_len_q, seq_len_k = q.size(-2), k.size(-2)
        if self.causal and attn_mask is None:
            attn_mask = torch.triu(torch.ones(seq_len_q, seq_len_k, device=q.device), diagonal=1).bool()
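            # For seq_len_q = seq_len_k = 4 this produces (True marks masked-out future positions):
            # [[False,  True,  True,  True],
            #  [False, False,  True,  True],
            #  [False, False, False,  True],
            #  [False, False, False, False]]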
        # Apply the causal mask
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))

        # Attention weights
        weights = F.softmax(scores, dim=-1)
        # Apply the padding mask again to ensure the attention weights at padded positions are exactly 0
        if padding_mask is not None:
            weights = weights.masked_fill(padding_mask.unsqueeze(-2), 0)
        # Multiply by v to obtain the output
        output = torch.matmul(weights, v)

        return output, weights


# Self-Attention
class SelfAttention(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        自注意力机制是注意力机制的一种特殊形式,其中 Q、K、V 都来自同一输入序列,
        其能够捕捉序列内部元素之间的关系,不依赖 RNN 或 CNN,直接建模长距离依赖。

        Args:
            d_model: 特征维度,默认为 512
        """
        super(SelfAttention, self).__init__()
        self.attention = Attention()  # scaled dot-product attention
        # A single linear layer that produces Q, K, and V together
        self.linear_qkv = nn.Linear(d_model, d_model * 3)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        _

        Args:
            x: 词嵌入 shape=(batch_size, seq_len, d_model)
            padding_mask: 填充掩码 shape=(batch_size, seq_len)
            attn_mask: 因果掩码 shape=(seq_len, seq_len)

        Returns:
            output: 输出 shape=(batch_size, seq_len, d_model)
        """
        # Generate Q, K, and V in a single pass through the linear layer
        qkv = self.linear_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)  # (batch_size, seq_len, d_model)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)

        return self.linear_out(output)


# Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int = 512, num_heads: int = 8, causal: bool = True) -> None:
        """
        Multi-head self-attention extends self-attention by splitting the input features into multiple heads.
        Each head computes attention independently and the results are concatenated,
        allowing the model to capture information from multiple representation subspaces.

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            causal: whether to automatically generate a causal mask, defaults to True
        """
        super(MultiHeadSelfAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, but got {d_model} and {num_heads}")
        self.num_heads = num_heads
        self.attention = Attention(causal)  # scaled dot-product attention
        # Separate linear projections for Q, K, and V
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        Q、K、V 在不同的自注意力模块中的来源可能不同,
        在编解码器自注意力中,Q 来自解码器的输入,K、V 来自编码器的输出。

        Args:
            q: 查询 shape=(batch_size, seq_len, d_model)
            k: 键 shape=(batch_size, seq_len / seq_len_k, d_model)
            v: 值 shape=(batch_size, seq_len / seq_len_k, d_model)
            padding_mask: 填充掩码 shape=(batch_size, seq_len / seq_len_k)
            attn_mask: 因果掩码 shape=(seq_len / seq_len_k, seq_len / seq_len_k)

        Returns:
            output: 输出 shape=(batch_size, seq_len, d_model)
        """
        q = self.linear_q(q)
        k = self.linear_k(k)
        v = self.linear_v(v)
        batch_size, seq_len, seq_len_k = q.size(0), q.size(1), k.size(1)
        # (batch_size, num_heads, seq_len, d_k)
        q = q.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        k = k.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
        v = v.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
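        # e.g. with d_model=512 and num_heads=8: (batch_size, seq_len, 512) -> (batch_size, 8, seq_len, 64)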
        # Reshape the masks to match the multi-head layout
        if padding_mask is not None:
            padding_mask = padding_mask.unsqueeze(1)  # (batch_size, 1, seq_len)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)  # (1, seq_len, seq_len)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)
        # Concatenate the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.linear_out(output)


# Residual Connection and Layer Normalization (Add & Norm)
class AddNorm(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        Add&Norm 层结合了两种操作:残差连接和层归一化,
        可以使模型在训练过程中更加稳定,并且通过堆叠多个这样的层来构建更深的模型。

        Args:
            d_model: 特征维度,默认为 512
        """
        super(AddNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)  # 层归一化

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
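        # Post-norm formulation: LayerNorm(x + sublayer_output), as in the original Transformer paper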
        return self.norm(x + y)


# Feed-Forward Layer
class FeedForward(nn.Module):
    def __init__(self,
                 input_dim: int = 512,
                 hidden_dim: int = 2048,
                 activation: str = "relu",
                 dropout: float = 0.1) -> None:
        """
        全连接层(扩大维度) -> 激活函数 -> 全连接层(恢复原始维度)
        通过非线性变换进一步提取和增强特征,使模型具备更强的模式识别和语义组合能力。

        Args:
            input_dim: 输入维度,默认为 512
            hidden_dim: 隐藏层维度,默认为 2048
            activation: 激活函数,默认为 "relu"
                -支持:"sigmoid", "tanh", "relu", "gelu", "leaky_relu", "elu"
            dropout: 丢弃率,默认为 0.1
        """
        super(FeedForward, self).__init__()
        match activation:  # select the activation function
            case "sigmoid":
                activation = nn.Sigmoid()
            case "tanh":
                activation = nn.Tanh()
            case "relu":
                activation = nn.ReLU()
            case "gelu":
                activation = nn.GELU()
            case "leaky_relu":
                activation = nn.LeakyReLU()
            case "elu":
                activation = nn.ELU()
            case _:
                raise ValueError(f"Unsupported activation function: {activation}")
        # Linear -> activation -> Dropout -> Linear
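        # Shapes: (batch_size, seq_len, input_dim) -> (..., hidden_dim) -> (..., input_dim)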
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)


# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: 特征维度,默认为 512
            num_heads: 头数,默认为 8
            dim_feedforward: FFN 隐藏层维度,默认为 2048
            dropout: 丢弃率,默认为 0.1
        """
        super(EncoderLayer, self).__init__()
        # 多头自注意力层
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        # Add&Norm 层
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        # 前馈层
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # 丢弃层
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
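        # Self-attention sublayer with residual connection and layer normalization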
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask, attn_mask)))
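        # Feed-forward sublayer with residual connection and layer normalization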
        x = self.norm2(x, self.dropout2(self.ffn(x)))

        return x


# Encoder
class Encoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        编码器由多个编码层组成,每个编码层结构相同但并不共享参数。

        Args:
            num_layers: 层数,默认为 6
            **params: 编码层参数,参考 EncoderLayer
        """
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        for layer in self.layers:  # pass the input through the layers one by one
            x = layer(x, padding_mask, attn_mask)

        return x


# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention -> AddNorm -> MultiHeadSelfAttention -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: 特征维度,默认为 512
            num_heads: 头数,默认为 8
            dim_feedforward: FFN 隐藏层维度,默认为 2048
            dropout: 丢弃率,默认为 0.1
        """
        super(DecoderLayer, self).__init__()
        # 多头自注意力层
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.cross_attn = MultiHeadSelfAttention(d_model, num_heads)
        # Add&Norm 层
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        self.norm3 = AddNorm(d_model)
        # 前馈层
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # 丢弃层
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        # Masked self-attention sublayer
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask_y, attn_mask_y)))
        # Cross-attention sublayer: queries come from the decoder, keys and values from the encoder output
        x = self.norm2(x, self.dropout2(self.cross_attn(x, memory, memory, padding_mask_memory, attn_mask_memory)))
        # Feed-forward sublayer
        x = self.norm3(x, self.dropout3(self.ffn(x)))

        return x


# Decoder
class Decoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        解码器由多个解码层组成,每个解码层结构相同但并不共享参数。

        Args:
            num_layers: 层数,默认为 6
            **params: 解码层参数,参考 DecoderLayer
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        for layer in self.layers:  # pass the input through the layers one by one
            x = layer(x, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)

        return x


# Transformer
class Transformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 **params) -> None:
        """
        transformer 是标准的编码器-解码器结构

        Args:
            num_encoder_layers: 编码器层数,默认为 6
            num_decoder_layers: 解码器层数,默认为 6
            **params: 编解码层参数,参考 EncoderLayer 和 DecoderLayer
        """
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_encoder_layers, **params)  # encoder stack
        self.decoder = Decoder(num_decoder_layers, **params)  # decoder stack

    def forward(self,
                x: Tensor,
                y: Tensor,
                padding_mask_x: Tensor = None,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_x: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
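        # Encode the source sequence; the result (memory) is attended to by every decoder layer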
        memory = self.encoder(x, padding_mask_x, attn_mask_x)
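        # Decode the target sequence, attending to itself (causally) and to the encoder memory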
        output = self.decoder(y, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)

        return output


if __name__ == '__main__':
    # attention = Attention(True)
    # t1, t2 = attention(
    #     torch.rand((2, 3, 64)),
    #     torch.rand((2, 5, 64)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, True, True, True, True],
    #                   [False, False, False, False, True]])
    # )
    # print(t1.shape, t2.shape)

    # self_attention = SelfAttention()
    # t3 = self_attention(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t3.shape)

    # multi_head_self_attention = MultiHeadSelfAttention(num_heads=2)
    # t4 = multi_head_self_attention(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t4.shape)

    # encoder_layer = EncoderLayer()
    # t5 = encoder_layer(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t5.shape)

    # encoder = Encoder(dropout=0.2)
    # t6 = encoder(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t6.shape)

    # decoder_layer = DecoderLayer()
    # t7 = decoder_layer(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t7.shape)

    # decoder = Decoder()
    # t8 = decoder(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t8.shape)

    transformer = Transformer()
    t9 = transformer(
        torch.rand((2, 5, 512)),
        torch.rand((2, 3, 512)),
        torch.tensor([[False, False, False, True, True],
                      [False, False, True, True, True]]),
        torch.tensor([[False, False, False],
                      [False, False, True]]),
    )
    print(t9.shape)
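    # Expected output: torch.Size([2, 3, 512])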

