自注意力、多头注意力与交叉注意力的PyTorch代码对比
1. 自注意力 (Self-Attention)
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.embed_dim = embed_dim
# 投影矩阵:Q/K/V共享输入维度
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
x: (batch_size, seq_len, embed_dim)
"""
# 1. 生成Q/K/V - 全部来自同一输入
Q = self.query(x) # (B, L, D)
K = self.key(x) # (B, L, D)
V = self.value(x) # (B, L, D)
# 2. 计算注意力分数
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))
# 3. 注意力权重归一化
attn_weights = self.softmax(attn_scores) # (B, L, L)
# 4. 加权求和
output = torch.matmul(attn_weights, V) # (B, L, D)
return output
核心特征:
- Q/K/V全部来自同一个输入序列
- 注意力分数矩阵维度为(L, L),表示序列内部的关系
- 输出序列长度和维度不变
2. 多头注意力 (Multi-Head Attention)
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 确保可分割
assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num_heads"
# 多头投影矩阵
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
# 输出层
self.fc_out = nn.Linear(embed_dim, embed_dim)
self.softmax = nn.Softmax(dim=-1)
def split_heads(self, x):
"""分割为多头"""
batch_size = x.size(0)
# (B, L, D) -> (B, L, H, HD) -> (B, H, L, HD)
return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
def forward(self, x):
"""多头自注意力"""
# 1. 生成Q/K/V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# 2. 分割为多头
Q = self.split_heads(Q) # (B, H, L, HD)
K = self.split_heads(K)
V = self.split_heads(V)
# 3. 计算注意力分数
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))
# 4. 注意力权重归一化
attn_weights = self.softmax(attn_scores) # (B, H, L, L)
# 5. 加权求和
attention = torch.matmul(attn_weights, V) # (B, H, L, HD)
# 6. 合并多头
attention = attention.transpose(1, 2).contiguous() # (B, L, H, HD)
attention = attention.view(attention.size(0), -1, self.embed_dim) # (B, L, D)
# 7. 输出投影
output = self.fc_out(attention)
return output
核心特征:
- 基于自注意力扩展
- 额外的分割(head splitting)和合并操作
- 每个头在降维后的子空间(HD)中计算
- 最终通过全连接层融合多头信息
3. 交叉注意力 (Cross-Attention)
class CrossAttention(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.embed_dim = embed_dim
# Query来自序列A,Key/Value来自序列B
self.query = nn.Linear(embed_dim, embed_dim) # for sequence A
self.key = nn.Linear(embed_dim, embed_dim) # for sequence B
self.value = nn.Linear(embed_dim, embed_dim) # for sequence B
self.softmax = nn.Softmax(dim=-1)
def forward(self, x_a, x_b):
"""
x_a: (batch_size, len_a, embed_dim) 序列A
x_b: (batch_size, len_b, embed_dim) 序列B
"""
# 1. 生成Q/K/V - 来自不同输入源
Q = self.query(x_a) # 来自序列A (B, La, D)
K = self.key(x_b) # 来自序列B (B, Lb, D)
V = self.value(x_b) # 来自序列B (B, Lb, D)
# 2. 计算注意力分数 (序列A到序列B的映射)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))
# 3. 注意力权重归一化
attn_weights = self.softmax(attn_scores) # (B, La, Lb)
# 4. 加权求和
output = torch.matmul(attn_weights, V) # (B, La, D)
return output
核心特征:
- Q来自一个序列,K/V来自另一个序列
- 注意力矩阵维度为(La, Lb),表示序列间关系
- 输出序列长度与查询序列相同(La),维度不变
- 不要求两个序列长度相同
三者的核心对比
特性 | 自注意力 | 多头注意力 | 交叉注意力 |
---|---|---|---|
输入序列数量 | 1个 | 1个 | 2个 |
Q来源 | 自身 | 自身 | 序列A |
K/V来源 | 自身 | 自身 | 序列B |
维度变换 | 无 | 分割头+合并 | 无 |
注意力矩阵 | (L, L) | (H, L, L) | (La, Lb) |
输出长度 | L | L | La |
主要用途 | 序列内关系 | 多角度特征提取 | 序列间关系建模 |
计算复杂度 | O(L²·D) | O(H·L²·HD) | O(La·Lb·D) |
使用场景示例
# 示例:序列长度均为5,嵌入维度128
x = torch.randn(2, 5, 128) # batch_size=2, seq_len=5, embed_dim=128
y = torch.randn(2, 3, 128) # 不同长度序列
# 1. 自注意力
self_attn = SelfAttention(embed_dim=128)
output_self = self_attn(x) # (2, 5, 128)
# 2. 多头注意力 (8头)
multihead_attn = MultiHeadAttention(embed_dim=128, num_heads=8)
output_multi = multihead_attn(x) # (2, 5, 128)
# 3. 交叉注意力
cross_attn = CrossAttention(embed_dim=128)
output_cross = cross_attn(x, y) # (2, 5, 128) - 保持查询序列长度
性能优化技巧
融合计算:现代PyTorch版本推荐使用优化API
# PyTorch 1.12+ 优化实现 output = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
内存优化:使用计算过程重算减少内存占用
with torch.cuda.amp.autocast(enabled=True): output = some_attention(Q, K, V)
稀疏注意力:对大序列使用稀疏矩阵
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention