从代码学习深度学习 - 注意力汇聚:注意力评分函数(加性和点积注意力) PyTorch 版

发布于:2025-04-12 ⋅ 阅读:(34) ⋅ 点赞:(0)


前言

在深度学习领域,注意力机制(Attention Mechanism)已经成为许多模型的核心组件,尤其是在自然语言处理(NLP)和计算机视觉任务中。注意力机制的核心思想是通过计算查询(Query)与键(Key)之间的相关性,动态地为值(Value)分配权重,从而聚焦于最重要的信息。本篇博客将通过 PyTorch 代码,深入探讨注意力汇聚(Attention Pooling)的两种常见评分函数:加性注意力(Additive Attention)和点积注意力(Dot Product Attention)。我们将从代码实现入手,逐步解析其原理,并通过可视化工具展示注意力权重的分布。
在这里插入图片描述

本文的目标读者是对深度学习有一定基础、希望通过代码理解注意力机制的实现细节的开发者。所有代码均基于 PyTorch,并在 Jupyter Notebook 中运行和测试。让我们开始吧!


一、掩蔽 Softmax 操作

在注意力机制中,掩蔽 Softmax(Masked Softmax)是一个关键步骤,用于确保模型只关注序列中的有效部分,避免对填充(padding)数据产生影响。我们先来看两个核心函数的实现:sequence_maskmasked_softmax

1.1 sequence_mask

sequence_mask 函数用于在序列中屏蔽不相关的项。它接收输入序列张量 X、有效长度张量 valid_len,并将无效位置替换为指定值(默认值为 0)。

import torch
import torch.nn as nn

def sequence_mask(X, valid_len, value=0):
    """
    在序列中屏蔽不相关的项
    
    参数:
        X: 输入序列张量,维度 [batch_size, maxlen]
        valid_len: 有效长度张量,维度 [batch_size]
        value: 填充值,标量,默认为0
    
    返回:
        X: 屏蔽后的序列张量,维度 [batch_size, maxlen]
    
    Defined in :numref:`sec_seq2seq_decoder`
    """
    # 获取序列的最大长度,维度为标量
    maxlen = X.size(1)
    
    # 创建掩码矩阵
    # torch.arange(maxlen): 生成 [0, 1, ..., maxlen-1] 的序列,维度 [maxlen]
    # [None, :] 将其扩展为 [1, maxlen]
    # valid_len[:, None] 将 [batch_size] 扩展为 [batch_size, 1]
    # 比较结果 mask 维度为 [batch_size, maxlen]
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    
    # 使用掩码将 X 中无效位置设为 value
    # ~mask 为反向掩码,选择需要填充的位置
    X[~mask] = value
    
    return X

这个函数的工作原理是:

  1. 通过 torch.arange(maxlen) 生成一个从 0 到 maxlen-1 的序列,并扩展为与批量大小匹配的形状。
  2. 使用广播机制,将 valid_len 与生成的序列比较,生成布尔掩码 mask
  3. 根据掩码,将无效位置(即超出有效长度的部分)替换为 value

1.2 masked_softmax

masked_softmax 函数在 Softmax 操作中加入掩蔽机制,确保无效位置的注意力权重为 0。

def masked_softmax(X, valid_lens):
    """
    通过在最后一个轴上掩蔽元素来执行softmax操作
    
    参数:
        X: 三维张量 (batch_size, seq_len, feature_dim)
        valid_lens: 一维张量 (batch_size,) 或二维张量 (batch_size, seq_len),表示有效长度
        
    返回:
        经过masked softmax处理的张量 (batch_size, seq_len, feature_dim)
    """
    if valid_lens is None:
        # 当没有指定有效长度时,直接执行标准softmax
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape  # shape: (batch_size, seq_len, feature_dim)
        if valid_lens.dim() == 1:
            # 将一维的valid_lens重复扩展到与X的第二维匹配
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将二维的valid_lens展平为一维
            valid_lens = valid_lens.reshape(-1)
            
        # 在最后一轴上对被掩蔽的元素使用非常大的负值替换,使其softmax输出为0
        X = sequence_mask(X.reshape(-1, shape[-1]), 
                          valid_lens,
                          value=-1e6)
        # 执

网站公告

今日签到

点亮在社区的每一天
去签到