【科研日常】线性注意力(Linear Attention)学习笔记

发布于:2025-08-16 ⋅ 阅读:(9) ⋅ 点赞:(0)

线性注意力(Linear Attention)

线性注意力是一种改进的注意力机制,旨在解决传统自注意力(Self-Attention)在处理长序列时计算和内存复杂度过高的问题。传统自注意力的计算复杂度是 O(N2)O(N^2)O(N2),而线性注意力通过一系列数学技巧将其降低到 O(N)O(N)O(N),更适合处理长文本或高分辨率图像。

🔍 一、什么是线性注意力?为什么是线性?

在标准自注意力中,核心操作如下:

Attention(Q,K,V)=softmax(QKTd)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(d QKT)V

其中:

  • Q,K,VQ, K, VQ,K,V 是查询(Query)、键(Key)、值(Value)矩阵,维度都是 N×dN \times dN×d
  • 这个公式涉及到计算 QKTQK^TQKT,它是一个 N×NN \times NN×N 的矩阵 —— 所以计算复杂度是 O(N2)O(N^2)O(N2)

而线性注意力的基本想法是:将 softmax 或其他非线性函数 近似替换成内积可以拆分的形式,也就是说将其写成:

Attention(Q,K,V)≈(ϕ(Q)(ϕ(K)TV)) \text{Attention}(Q, K, V) \approx (\phi(Q) (\phi(K)^T V)) Attention(Q,K,V)(ϕ(Q)(ϕ(K)TV))

其中 ϕ(⋅)\phi(\cdot)ϕ() 是一个正定映射函数(如 ReLU、ELU、exp、kernel trick等),可以将原本的 QKTQK^TQKT 操作变成:

  1. 先计算 KTVK^T VKTV,维度是 d×dd \times dd×d,成本是 O(Nd2)O(Nd^2)O(Nd2)
  2. 然后再乘上 ϕ(Q)\phi(Q)ϕ(Q),成本是 O(Nd2)O(Nd^2)O(Nd2)

从而整体注意力过程从原本的 O(N2d)O(N^2d)O(N2d) 降为 O(Nd2)O(Nd^2)O(Nd2),在 N≫dN \gg dNd 的时候是巨大的提升。

🧠 二、与传统注意力的区别

项目 传统自注意力 线性注意力
复杂度 O(N2)O(N^2)O(N2) O(N)O(N)O(N)
可并行性
对长序列的支持 差(容易OOM)
精度 略有损失(视实现而定)
使用softmax 否(用可拆解核函数近似)

✅ 三、实际应用场景

线性注意力已经应用于以下几个方向:

  • 长文档生成(如 GPT-类模型压缩长上下文)
  • 高分辨率图像建模(Vision Transformers)
  • 语音识别与语音建模
  • 嵌入设备(低功耗硬件)上运行 Transformer

典型模型包括:

  • Performer
  • Linformer
  • Linear Transformers(本文提到的 ICML 2020 论文)

📊 四、可视化比较

标准注意力计算流程:

        Q (N x d)
         |
         v
        QK^T (N x N)  <---- K (N x d)
         |
         v
      softmax
         |
         v
      Weighted sum with V (N x d)

线性注意力计算流程:

      φ(Q) (N x d)
         |
         v
     φ(K)^T V  (d x d)  <--- φ(K) (N x d), V (N x d)
         |
         v
     Final output = φ(Q) × (φ(K)^T V)

不再显式构造 N×NN \times NN×N 的中间矩阵,节省空间和计算。

🔧 五、Python代码对比(PyTorch)

✅ 标准注意力代码(简化版):

import torch
import torch.nn.functional as F

def standard_attention(Q, K, V):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k**0.5
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V)

✅ 线性注意力代码(基于ELU核):

def elu_feature_map(x):
    return F.elu(x) + 1  # 保证非负

def linear_attention(Q, K, V):
    Q_prime = elu_feature_map(Q)  # φ(Q)
    K_prime = elu_feature_map(K)  # φ(K)

    KV = torch.einsum('nld,nle->lde', K_prime, V)  # φ(K)^T V
    Z = 1 / (torch.einsum('nld,ld->nl', Q_prime, K_prime.sum(dim=0)) + 1e-6)  # 正则项
    output = torch.einsum('nld,lde,nl->nle', Q_prime, KV, Z)  # 最终输出
    return output

🧾 六、总结

  • 线性注意力本质上是通过数学变换将注意力矩阵“内积核”结构分解,避免显式计算 N×NN \times NN×N 的矩阵。
  • 它能显著提高模型效率和可扩展性,特别适用于长序列任务。
  • 虽然可能略微损失精度,但对于许多工程实际来说是可以接受的。

网站公告

今日签到

点亮在社区的每一天
去签到