Flash Attention

发布于:2024-12-20 ⋅ 阅读:(17) ⋅ 点赞:(0)

Flash Attention: 高效注意力机制解析

什么是 Flash Attention?

Flash Attention 是一种针对 Transformer 模型 优化的高效注意力计算方法。与传统注意力机制相比,它通过 分块计算显存优化数值稳定性改进,实现了在 长序列任务 中的显著加速,同时大幅降低了显存占用。


Flash Attention 与普通 Attention 的对比

特性 普通 Attention Flash Attention
计算复杂度 O ( n 2 ) O(n^2) O(n2),长序列显存占用高 O ( n 2 ) O(n^2) O(n2),通过分块优化显存使用
显存占用 必须存储完整的注意力矩阵 n × n n \times n n×n 分块计算避免存储完整矩阵,显存开销显著降低
数值稳定性 可能因 Softmax 计算溢出导致不稳定 分块归一化(log-sum-exp 技术)保证数值稳定性
适用场景 适合短序列任务 长序列任务的理想选择,如长文档建模、视频建模

为什么选择 Flash Attention?

优点

  1. 显存高效:避免存储完整的注意力矩阵,支持更长的序列处理。
  2. 计算快速:使用分块和 CUDA 优化,比普通 Attention 加速 2-4 倍。
  3. 数值稳定:改进 Softmax 的实现,支持更大的输入范围。
  4. 适合长序列任务:如 NLP 长文档处理、生物信息学蛋白质序列建模、高分辨率视频分析。

局限性

  1. 实现复杂:依赖 CUDA 核心优化,难以手动实现完整功能。
  2. 硬件要求高:需要现代 GPU 和高效的内存管理。

Flash Attention 的工作原理

核心机制

  1. 传统公式
    Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dk QKT)V

  2. Flash Attention 的优化

    • 分块计算:避免存储完整的 n × n n \times n n×n 矩阵。
    • 块内归一化
      Softmax ( x ) = exp ⁡ ( x − max ⁡ ( x ) ) ∑ exp ⁡ ( x − max ⁡ ( x ) ) \text{Softmax}(x) = \frac{\exp(x - \max(x))}{\sum \exp(x - \max(x))} Softmax(x)=exp(xmax(x))exp(xmax(x))
    • CUDA 并行化:结合 kernel fusion 实现高效矩阵运算。

Flash Attention 实现代码

普通 Attention 示例

import torch

def attention(Q, K, V):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
    attention = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention, V)
    return output

Flash Attention 示例(简化)

def flash_attention(Q, K, V, block_size=32):
    batch_size, seq_len, hidden_dim = Q.size()
    d_k = hidden_dim
    output = torch.zeros_like(Q)

    for i in range(0, seq_len, block_size):
        for j in range(0, seq_len, block_size):
            Q_block = Q[:, i:i+block_size, :]
            K_block = K[:, j:j+block_size, :]
            V_block = V[:, j:j+block_size, :]

            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
            max_scores = torch.max(scores, dim=-1, keepdim=True)[0]
            scores = scores - max_scores
            attention = torch.exp(scores)
            attention = attention / torch.sum(attention, dim=-1, keepdim=True)

            output[:, i:i+block_size, :] += torch.matmul(attention, V_block)

    return output

Flash Attention 的工作原理展示

Flash Attention 的核心优化在于 分块计算(Blockwise Attention Calculation),通过分块减少显存占用,并保持计算效率和数值稳定性。以下是 Flash Attention 的工作流程及分块计算的具体实现细节:


核心流程

Flash Attention 的实现主要分为以下几步:

  1. 输入序列分块

    • 将输入的 QKV 分成小块(block_size),避免一次性计算完整的注意力矩阵。
    • 每个块分别计算局部的点积、Softmax 和加权结果。
  2. 块内注意力计算

    • 对每个块内计算注意力分布,使用数值稳定的 Softmax 优化,避免数值溢出问题。
  3. 逐块累积输出

    • 将分块结果逐步累积,得到最终的全局注意力输出。

示例:分块计算

输入矩阵

假设:

  1. 输入矩阵 Q(Query):形状为 4 × 4,表示序列长度为 4,隐藏维度为 4。
  2. 输入矩阵 K(Key):形状为 4 × 4,与 Q 的形状一致。
  3. 输入矩阵 V(Value):形状为 4 × 4,与 QK 的形状一致。
  4. 分块大小 block_size:假设为 2,表示每次处理 2 个序列块。
示例输入矩阵
Q = [[1, 2, 3, 4],
     [4, 3, 2, 1],
     [1, 1, 1, 1],
     [2, 2, 2, 2]]

K = [[1, 0, 1, 0],
     [0, 1, 0, 1],
     [1, 1, 1, 1],
     [2, 2, 2, 2]]

V = [[1, 1, 1, 1],
     [2, 2, 2, 2],
     [3, 3, 3, 3],
     [4, 4, 4, 4]]

步骤 1:分块

将 Q 和 K 按行进行分块,每块大小为 block_size=2:

Q 分块:
Q_1 = [[1, 2, 3, 4],
       [4, 3, 2, 1]]

Q_2 = [[1, 1, 1, 1],
       [2, 2, 2, 2]]

K 分块:
K_1 = [[1, 0, 1, 0],
       [0, 1, 0, 1]]

K_2 = [[1, 1, 1, 1],
       [2, 2, 2, 2]]

V 分块:
V_1 = [[1, 1, 1, 1],
       [2, 2, 2, 2]]

V_2 = [[3, 3, 3, 3],
       [4, 4, 4, 4]]

步骤 2:块间点积计算

计算每个块的点积 Q_block × K_block^T,并缩放:

计算 ( Q_1 \times K_1^T ):

Q_1 × K_1^T = [[1, 2, 3, 4]   ×   [[1, 0, 1, 0]^T    = [[15, 14],
                    [4, 3, 2, 1]]        [0, 1, 0, 1]]     [10, 20]]

缩放结果(假设隐藏维度 ( d_k = 4 ),缩放因子为 ( \sqrt{4} = 2 )):

Scores = [[15 / 2, 14 / 2],
          [10 / 2, 20 / 2]] = [[7.5, 7.0],
                               [5.0, 10.0]]

计算 ( Q_1 \times K_2^T ):

Q_1 × K_2^T = [[1, 2, 3, 4]   ×   [[1, 1, 1, 1]^T    = [[30, 60],
                    [4, 3, 2, 1]]        [2, 2, 2, 2]]     [20, 40]]

缩放结果:

Scores = [[30 / 2, 60 / 2],
          [20 / 2, 40 / 2]] = [[15.0, 30.0],
                               [10.0, 20.0]]

步骤 3:数值稳定的 Softmax 计算

使用 最大值减法 技术对每个分块的 Scores 计算 Softmax,避免数值溢出。

对 ( Q_1 \times K_1^T ) 的 Scores 计算:

Scores = [[7.5, 7.0],
          [5.0, 10.0]]

最大值减法:[[7.5 - 7.5, 7.0 - 7.5],
           [5.0 - 10.0, 10.0 - 10.0]] = [[0, -0.5],
                                        [-5,  0]]

指数计算:[[exp(0), exp(-0.5)],
          [exp(-5), exp(0)]] = [[1.0, 0.6065],
                                [0.0067, 1.0]]

Softmax 归一化:[[1.0 / (1.0 + 0.6065), 0.6065 / (1.0 + 0.6065)],
               [0.0067 / (0.0067 + 1.0), 1.0 / (0.0067 + 1.0)]] = [[0.622, 0.378],
                                                                [0.007, 0.993]]

步骤 4:加权输出计算

使用 Softmax 权重和 V 的块计算加权输出。

( Q_1 ) 的加权结果:

Output_1 = Softmax(Q_1 × K_1^T) × V_1 = [[0.622, 0.378]   ×   [[1, 1, 1, 1],
                                           [2, 2, 2, 2]]     [2.622, 2.622, 2.622, 2.622]]

类似地,依次计算 ( Q_2 × K_2^T ) 的输出,逐块累积所有块的结果。

最终输出

将所有分块的输出累加到最终结果矩阵中,得到完整的注意力结果矩阵 Output。

Flash Attention 的优势

	1.	显存优化:
	•	普通 Attention 需要存储完整的注意力矩阵,显存占用为 ( O(n^2) )。
	•	Flash Attention 仅存储分块结果,显存占用为 ( O(n \cdot \text{block_size}) )。
	2.	计算效率:
	•	分块计算可以并行化,结合 CUDA 核心优化,速度显著提高。
	3.	数值稳定性:
	•	使用块级 Softmax 和最大值归一化,避免长序列点积的数值溢出问题。