多种注意力机制(文本->残差->视频)

发布于:2025-03-14 ⋅ 阅读:(13) ⋅ 点赞:(0)

1.初代自我注意机制(多头注意力机制)

1.1原理

总体架构

 

上图是 Self-Attention 的结构,在计算的时候需要用到矩阵Q(查询),K(键值),V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量x组成的矩阵X) 或者上一个 Encoder block 的输出。而Q,K,V正是通过 Self-Attention 的输入进行线性变换得到的。

Q, K, V 的计算  这个X无疑是由不同单词(词向量)组成的

如下图无疑是计算出了Q K V

 多头注意力的输出

 公式中计算矩阵QK每一行向量的内积,为了防止内积过大,因此除以 d_{k}的平方根。Q乘以K的转置后,得到的矩阵行列数都为 n,n 为句子单词数,这个矩阵可以表示单词之间的 attention 强度

换一话来说 对图像而言 是不是patch 对视频而言 是不是帧)!!!

QK^{T}之后,使用 Softmax 计算每一个单词对于其他单词的 attention 系数,公式中的 Softmax 是对矩阵的每一行进行 Softmax,即每一行的和都变为 1. 

 得到 Softmax 矩阵之后可以和V相乘,得到最终的输出Z

就拿结果而言我们是不是得到了一个带有位置信息 和对所有样本有注意力系数的词向量 (可以一个patch 也可以是一个帧)

多头注意力

在上一步,我们已经知道怎么通过 Self-Attention 计算得到输出矩阵 Z,而 Multi-Head Attention 是由多个 Self-Attention 组合形成的,下图是论文中 Multi-Head Attention 的结构图。

 从上图可以看到 Multi-Head Attention 包含多个 Self-Attention 层,首先将输入X分别传递到 h 个不同的 Self-Attention 中,计算得到 h 个输出矩阵Z。下图是 h=8 时候的情况,此时会得到 8 个输出矩阵

 得到 8 个输出矩阵 Z1 到 Z8 之后,Multi-Head Attention 将它们拼接在一起 (Concat),然后传入一个Linear层,得到 Multi-Head Attent

最终的输出Z 

掩码注意力就是看不见未来的情况

参考这一篇transform

代码实现

调用api

nn.MultiheadAttention(d_model, n_head)

 def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

2.残差注意力机制(图像识别用的多)

工作过程

  • 特征提取:首先,输入数据经过一系列的卷积层和池化层等操作进行特征提取,得到初始的特征图。

  • 注意力计算:将初始特征图输入到注意力模块中,通过一系列的计算生成注意力掩码。这个过程通常包括对特征图进行卷积操作、激活函数处理以及归一化等步骤,以得到每个位置的注意力权重。

  • 特征加权:将注意力掩码与初始特征图相乘,使得特征图中重要的部分得到增强,不重要的部分得到抑制。

  • 残差融合:将经过特征加权后的特征图与原始输入特征图通过残差连接相加,得到最终的输出特征图。这个输出特征图既包含了经过注意力机制筛选后的重要特征,又保留了原始输入中的信息,从而提高了特征的质量和模型的性能。

代码实现

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

3.跨帧注意力块(视频用的多)

跨帧融合注意力(CFA)

CFA 的作用是允许帧之间直接交换信息。具体来说:

  • 每个帧都有一个 消息令牌(message token),用于抽象当前帧的视觉信息。

  • 这些消息令牌会参与到全局的时间依赖性建模中。

  • 消息令牌的计算方式如下:

m_t^{(l)} = W_m \cdot z_t^{(l-1), [class]}

 m_t^{(l)}为t帧在l层的消息令牌 w是一个线性矩阵,z_t^{(l-1), [class]}是前一层的特征输出。

所有帧的消息令牌会参与到跨帧融合注意力中

\hat{M}^{(l)} = M^{(l)} + \text{CFA}(\text{LN}(M^{(l)}))

CFA 是一个自注意力模块,它允许消息令牌之间相互作用,从而捕捉帧之间的时间依赖性

帧内融合注意力(IFA)

IFA 的作用是在帧内传播全局的时间信息,增强帧的表示能力

[\hat{z}_t^{(l)}, \bar{m}_t^{(l)}] = [z_t^{(l-1)}, \hat{m}_t^{(l)}] + \text{IFA}(\text{LN}([z_t^{(l-1)}, \hat{m}_t^{(l)}]))

  • 其中,\hat{z}_t^{(l)}是帧特征的更新表示,mˉt(l)​ 是消息令牌的更新表示。

  • 消息令牌在每个块中仅用于信息交换,不会传递到下一个块。

CFA

首先一个数据由(b,t,c.h,w)->(bt,c,h,w)->(wxh,bt,d)(你得了解这个视频处理)

 l, bt, d = x.size()

然后为了得到时间 一个知识点 三维是图像块输入 四维是帧输入

b = bt // self.T
x = x.view(l, b, self.T, d) 

初始化消息令牌(vit的csl)


msg_token = self.message_fc(x[0,:,:,:]) # # 使用线性变换初始化消息令牌 使用线性变换初始化消#息令牌 这个代表帧在不同批次和步长
msg_token = msg_token.view(b, self.T, 1, d) # 调整形状为 (b, T, 1, d)
msg_token = msg_token.permute(1,2,0,3).view(self.T, b, d)

 实现跨帧传输(t不同无疑就是 不同帧)

msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token),self.message_ln(msg_token),self.message_ln(msg_token),need_weights=False)[0])

 IFA:帧内

 x = x.view(l+1, -1, d)
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x[:l,:,:]

4.补充一个知识点

MultiHeadAttention

 传入一个输入的多头注意力机制:这种形式可以让模型关注序列内部的上下文信息,捕捉序列中不同位置之间的依赖关系。例如在文本处理中,一个句子中的每个词都可以通过自注意力机制与其他词进行交互,从而更好地理解整个句子的语义。

传入三个输入的多头注意力机制:形式允许模型在不同的表示子空间中并行地计算注意力,从而捕捉输入序列之间的不同类型的依赖关系。通过分别传入 Q、K、V,模型可以灵活地根据查询去关注键值对中的不同部分,提高模型的表达能力。