多头注意力机制(MHA)使用精要
核心作用: 捕捉序列数据的全局依赖关系,让每个时间点都能关注到所有其他时间点。
关键参数 (__init__
)
embed_dim
: 特征维度 (C)。必须与输入到MHA层的数据的特征维度完全一致。num_heads
: 头的数量。embed_dim
必须能被num_heads
整除。batch_first=True
: 务必设为True
。这规定了MHA期望的输入格式为(N, L, C)
。
实现蓝图 (forward
pass)
在卷积网络(输入为 (N, C, L)
)中使用MHA,遵循以下三步即可:
格式转换 (Permute In):
x = x.permute(0, 2, 1)
目的:将
(N, C, L)
转换为MHA期望的(N, L, C)
。
应用注意力块 (Attention Block):
attn_out, _ = self.mha(x, x, x)
x = self.norm(x + attn_out)
目的:执行自注意力计算,并用残差连接和层归一化稳定训练。
格式恢复 (Permute Back):
x = x.permute(0, 2, 1)
目的:将
(N, L, C)
转换回(N, C, L)
,以适配后续的卷积层。
黄金法则: MHA的 embed_dim
参数值,必须等于你的数据在进入MHA模块时的特征维度(通道数C),而不是最原始信号的维度。
import torch
import torch.nn as nn
class AttentionBlock(nn.Module):
def __init__(self, embed_dim, num_heads):
super(AttentionBlock, self).__init__()
# 确保 embed_dim 能被 num_heads 整除
if embed_dim % num_heads != 0:
raise ValueError(f"embed_dim ({embed_dim}) 必须能被 num_heads ({num_heads}) 整除。")
self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
# x 的输入格式应为 (N, C, L),这是CNN的典型输出格式
N, C, L = x.shape
# --- 配方第1步: 格式准备 ---
# (N, C, L) -> (N, L, C)
x_permuted = x.permute(0, 2, 1)
# --- 配方第2步: 自注意力计算 ---
attn_output, _ = self.mha(x_permuted, x_permuted, x_permuted)
# --- 配方第3步: 稳定与融合 ---
# 残差连接 + 层归一化
x_stabilized = self.norm(x_permuted + attn_output)
# --- 配方第4步: 格式恢复 ---
# (N, L, C) -> (N, C, L)
final_output = x_stabilized.permute(0, 2, 1)
return final_output
# --- 使用示例 ---
# 假设我们有一个来自CNN的输出
cnn_output = torch.randn(32, 64, 1024) # (N, C, L)
# 创建并使用注意力块
attention_block = AttentionBlock(embed_dim=64, num_heads=8)
processed_output = attention_block(cnn_output)
print(f"输入形状: {cnn_output.shape}")
print(f"输出形状: {processed_output.shape}") # 输出形状应与输入完全相同