自注意力,多头注意力,交叉注意力代码对比

发布于:2025-06-05 ⋅ 阅读:(20) ⋅ 点赞:(0)

自注意力、多头注意力与交叉注意力的PyTorch代码对比

1. 自注意力 (Self-Attention)

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # 投影矩阵:Q/K/V共享输入维度
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        """
        x: (batch_size, seq_len, embed_dim)
        """
        # 1. 生成Q/K/V - 全部来自同一输入
        Q = self.query(x)  # (B, L, D)
        K = self.key(x)    # (B, L, D)
        V = self.value(x)  # (B, L, D)
        
        # 2. 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))
        
        # 3. 注意力权重归一化
        attn_weights = self.softmax(attn_scores)  # (B, L, L)
        
        # 4. 加权求和
        output = torch.matmul(attn_weights, V)  # (B, L, D)
        return output

核心特征

  • Q/K/V全部来自同一个输入序列
  • 注意力分数矩阵维度为(L, L),表示序列内部的关系
  • 输出序列长度和维度不变

2. 多头注意力 (Multi-Head Attention)

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # 确保可分割
        assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num_heads"
        
        # 多头投影矩阵
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        
        # 输出层
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        
    def split_heads(self, x):
        """分割为多头"""
        batch_size = x.size(0)
        # (B, L, D) -> (B, L, H, HD) -> (B, H, L, HD)
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
    def forward(self, x):
        """多头自注意力"""
        # 1. 生成Q/K/V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        
        # 2. 分割为多头
        Q = self.split_heads(Q)  # (B, H, L, HD)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 3. 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))
        
        # 4. 注意力权重归一化
        attn_weights = self.softmax(attn_scores)  # (B, H, L, L)
        
        # 5. 加权求和
        attention = torch.matmul(attn_weights, V)  # (B, H, L, HD)
        
        # 6. 合并多头
        attention = attention.transpose(1, 2).contiguous()  # (B, L, H, HD)
        attention = attention.view(attention.size(0), -1, self.embed_dim)  # (B, L, D)
        
        # 7. 输出投影
        output = self.fc_out(attention)
        return output

核心特征

  • 基于自注意力扩展
  • 额外的分割(head splitting)和合并操作
  • 每个头在降维后的子空间(HD)中计算
  • 最终通过全连接层融合多头信息

3. 交叉注意力 (Cross-Attention)

class CrossAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Query来自序列A,Key/Value来自序列B
        self.query = nn.Linear(embed_dim, embed_dim)  # for sequence A
        self.key = nn.Linear(embed_dim, embed_dim)   # for sequence B
        self.value = nn.Linear(embed_dim, embed_dim) # for sequence B
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x_a, x_b):
        """
        x_a: (batch_size, len_a, embed_dim)  序列A
        x_b: (batch_size, len_b, embed_dim)  序列B
        """
        # 1. 生成Q/K/V - 来自不同输入源
        Q = self.query(x_a)   # 来自序列A (B, La, D)
        K = self.key(x_b)     # 来自序列B (B, Lb, D)
        V = self.value(x_b)   # 来自序列B (B, Lb, D)
        
        # 2. 计算注意力分数 (序列A到序列B的映射)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))
        
        # 3. 注意力权重归一化
        attn_weights = self.softmax(attn_scores)  # (B, La, Lb)
        
        # 4. 加权求和
        output = torch.matmul(attn_weights, V)  # (B, La, D)
        return output

核心特征

  • Q来自一个序列,K/V来自另一个序列
  • 注意力矩阵维度为(La, Lb),表示序列间关系
  • 输出序列长度与查询序列相同(La),维度不变
  • 不要求两个序列长度相同

三者的核心对比

特性 自注意力 多头注意力 交叉注意力
输入序列数量 1个 1个 2个
Q来源 自身 自身 序列A
K/V来源 自身 自身 序列B
维度变换 分割头+合并
注意力矩阵 (L, L) (H, L, L) (La, Lb)
输出长度 L L La
主要用途 序列内关系 多角度特征提取 序列间关系建模
计算复杂度 O(L²·D) O(H·L²·HD) O(La·Lb·D)

使用场景示例

# 示例:序列长度均为5,嵌入维度128
x = torch.randn(2, 5, 128)  # batch_size=2, seq_len=5, embed_dim=128
y = torch.randn(2, 3, 128)  # 不同长度序列

# 1. 自注意力
self_attn = SelfAttention(embed_dim=128)
output_self = self_attn(x)  # (2, 5, 128)

# 2. 多头注意力 (8头)
multihead_attn = MultiHeadAttention(embed_dim=128, num_heads=8)
output_multi = multihead_attn(x)  # (2, 5, 128)

# 3. 交叉注意力
cross_attn = CrossAttention(embed_dim=128)
output_cross = cross_attn(x, y)  # (2, 5, 128) - 保持查询序列长度

性能优化技巧

  1. 融合计算:现代PyTorch版本推荐使用优化API

    # PyTorch 1.12+ 优化实现
    output = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
    
  2. 内存优化:使用计算过程重算减少内存占用

    with torch.cuda.amp.autocast(enabled=True):
        output = some_attention(Q, K, V)
    
  3. 稀疏注意力:对大序列使用稀疏矩阵

    from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
    

网站公告

今日签到

点亮在社区的每一天
去签到