文章目录
前言
在深度学习领域,注意力机制(Attention Mechanism)已经成为许多模型的核心组件,尤其是在自然语言处理(NLP)和计算机视觉任务中。注意力机制的核心思想是通过计算查询(Query)与键(Key)之间的相关性,动态地为值(Value)分配权重,从而聚焦于最重要的信息。本篇博客将通过 PyTorch 代码,深入探讨注意力汇聚(Attention Pooling)的两种常见评分函数:加性注意力(Additive Attention)和点积注意力(Dot Product Attention)。我们将从代码实现入手,逐步解析其原理,并通过可视化工具展示注意力权重的分布。
本文的目标读者是对深度学习有一定基础、希望通过代码理解注意力机制的实现细节的开发者。所有代码均基于 PyTorch,并在 Jupyter Notebook 中运行和测试。让我们开始吧!
一、掩蔽 Softmax 操作
在注意力机制中,掩蔽 Softmax(Masked Softmax)是一个关键步骤,用于确保模型只关注序列中的有效部分,避免对填充(padding)数据产生影响。我们先来看两个核心函数的实现:sequence_mask
和 masked_softmax
。
1.1 sequence_mask
sequence_mask
函数用于在序列中屏蔽不相关的项。它接收输入序列张量 X
、有效长度张量 valid_len
,并将无效位置替换为指定值(默认值为 0)。
import torch
import torch.nn as nn
def sequence_mask(X, valid_len, value=0):
"""
在序列中屏蔽不相关的项
参数:
X: 输入序列张量,维度 [batch_size, maxlen]
valid_len: 有效长度张量,维度 [batch_size]
value: 填充值,标量,默认为0
返回:
X: 屏蔽后的序列张量,维度 [batch_size, maxlen]
Defined in :numref:`sec_seq2seq_decoder`
"""
# 获取序列的最大长度,维度为标量
maxlen = X.size(1)
# 创建掩码矩阵
# torch.arange(maxlen): 生成 [0, 1, ..., maxlen-1] 的序列,维度 [maxlen]
# [None, :] 将其扩展为 [1, maxlen]
# valid_len[:, None] 将 [batch_size] 扩展为 [batch_size, 1]
# 比较结果 mask 维度为 [batch_size, maxlen]
mask = torch.arange((maxlen), dtype=torch.float32,
device=X.device)[None, :] < valid_len[:, None]
# 使用掩码将 X 中无效位置设为 value
# ~mask 为反向掩码,选择需要填充的位置
X[~mask] = value
return X
这个函数的工作原理是:
- 通过
torch.arange(maxlen)
生成一个从 0 到maxlen-1
的序列,并扩展为与批量大小匹配的形状。 - 使用广播机制,将
valid_len
与生成的序列比较,生成布尔掩码mask
。 - 根据掩码,将无效位置(即超出有效长度的部分)替换为
value
。
1.2 masked_softmax
masked_softmax
函数在 Softmax 操作中加入掩蔽机制,确保无效位置的注意力权重为 0。
def masked_softmax(X, valid_lens):
"""
通过在最后一个轴上掩蔽元素来执行softmax操作
参数:
X: 三维张量 (batch_size, seq_len, feature_dim)
valid_lens: 一维张量 (batch_size,) 或二维张量 (batch_size, seq_len),表示有效长度
返回:
经过masked softmax处理的张量 (batch_size, seq_len, feature_dim)
"""
if valid_lens is None:
# 当没有指定有效长度时,直接执行标准softmax
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape # shape: (batch_size, seq_len, feature_dim)
if valid_lens.dim() == 1:
# 将一维的valid_lens重复扩展到与X的第二维匹配
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
# 将二维的valid_lens展平为一维
valid_lens = valid_lens.reshape(-1)
# 在最后一轴上对被掩蔽的元素使用非常大的负值替换,使其softmax输出为0
X = sequence_mask(X.reshape(-1, shape[-1]),
valid_lens,
value=-1e6)
# 执