文章目录
Flash Attention: 高效注意力机制解析
什么是 Flash Attention?
Flash Attention 是一种针对 Transformer 模型 优化的高效注意力计算方法。与传统注意力机制相比,它通过 分块计算、显存优化 和 数值稳定性改进,实现了在 长序列任务 中的显著加速,同时大幅降低了显存占用。
Flash Attention 与普通 Attention 的对比
特性 | 普通 Attention | Flash Attention |
---|---|---|
计算复杂度 | O ( n 2 ) O(n^2) O(n2),长序列显存占用高 | O ( n 2 ) O(n^2) O(n2),通过分块优化显存使用 |
显存占用 | 必须存储完整的注意力矩阵 n × n n \times n n×n | 分块计算避免存储完整矩阵,显存开销显著降低 |
数值稳定性 | 可能因 Softmax 计算溢出导致不稳定 | 分块归一化(log-sum-exp 技术)保证数值稳定性 |
适用场景 | 适合短序列任务 | 长序列任务的理想选择,如长文档建模、视频建模 |
为什么选择 Flash Attention?
优点
- 显存高效:避免存储完整的注意力矩阵,支持更长的序列处理。
- 计算快速:使用分块和 CUDA 优化,比普通 Attention 加速 2-4 倍。
- 数值稳定:改进 Softmax 的实现,支持更大的输入范围。
- 适合长序列任务:如 NLP 长文档处理、生物信息学蛋白质序列建模、高分辨率视频分析。
局限性
- 实现复杂:依赖 CUDA 核心优化,难以手动实现完整功能。
- 硬件要求高:需要现代 GPU 和高效的内存管理。
Flash Attention 的工作原理
核心机制
传统公式:
Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dkQKT)VFlash Attention 的优化:
- 分块计算:避免存储完整的 n × n n \times n n×n 矩阵。
- 块内归一化:
Softmax ( x ) = exp ( x − max ( x ) ) ∑ exp ( x − max ( x ) ) \text{Softmax}(x) = \frac{\exp(x - \max(x))}{\sum \exp(x - \max(x))} Softmax(x)=∑exp(x−max(x))exp(x−max(x)) - CUDA 并行化:结合 kernel fusion 实现高效矩阵运算。
Flash Attention 实现代码
普通 Attention 示例
import torch
def attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, V)
return output
Flash Attention 示例(简化)
def flash_attention(Q, K, V, block_size=32):
batch_size, seq_len, hidden_dim = Q.size()
d_k = hidden_dim
output = torch.zeros_like(Q)
for i in range(0, seq_len, block_size):
for j in range(0, seq_len, block_size):
Q_block = Q[:, i:i+block_size, :]
K_block = K[:, j:j+block_size, :]
V_block = V[:, j:j+block_size, :]
scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
max_scores = torch.max(scores, dim=-1, keepdim=True)[0]
scores = scores - max_scores
attention = torch.exp(scores)
attention = attention / torch.sum(attention, dim=-1, keepdim=True)
output[:, i:i+block_size, :] += torch.matmul(attention, V_block)
return output
Flash Attention 的工作原理展示
Flash Attention 的核心优化在于 分块计算(Blockwise Attention Calculation),通过分块减少显存占用,并保持计算效率和数值稳定性。以下是 Flash Attention 的工作流程及分块计算的具体实现细节:
核心流程
Flash Attention 的实现主要分为以下几步:
输入序列分块:
- 将输入的
Q
、K
和V
分成小块(block_size
),避免一次性计算完整的注意力矩阵。 - 每个块分别计算局部的点积、Softmax 和加权结果。
- 将输入的
块内注意力计算:
- 对每个块内计算注意力分布,使用数值稳定的 Softmax 优化,避免数值溢出问题。
逐块累积输出:
- 将分块结果逐步累积,得到最终的全局注意力输出。
示例:分块计算
输入矩阵
假设:
- 输入矩阵
Q
(Query):形状为4 × 4
,表示序列长度为 4,隐藏维度为 4。 - 输入矩阵
K
(Key):形状为4 × 4
,与Q
的形状一致。 - 输入矩阵
V
(Value):形状为4 × 4
,与Q
和K
的形状一致。 - 分块大小
block_size
:假设为 2,表示每次处理 2 个序列块。
示例输入矩阵
Q = [[1, 2, 3, 4],
[4, 3, 2, 1],
[1, 1, 1, 1],
[2, 2, 2, 2]]
K = [[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 1, 1, 1],
[2, 2, 2, 2]]
V = [[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]]
步骤 1:分块
将 Q 和 K 按行进行分块,每块大小为 block_size=2:
Q 分块:
Q_1 = [[1, 2, 3, 4],
[4, 3, 2, 1]]
Q_2 = [[1, 1, 1, 1],
[2, 2, 2, 2]]
K 分块:
K_1 = [[1, 0, 1, 0],
[0, 1, 0, 1]]
K_2 = [[1, 1, 1, 1],
[2, 2, 2, 2]]
V 分块:
V_1 = [[1, 1, 1, 1],
[2, 2, 2, 2]]
V_2 = [[3, 3, 3, 3],
[4, 4, 4, 4]]
步骤 2:块间点积计算
计算每个块的点积 Q_block × K_block^T,并缩放:
计算 ( Q_1 \times K_1^T ):
Q_1 × K_1^T = [[1, 2, 3, 4] × [[1, 0, 1, 0]^T = [[15, 14],
[4, 3, 2, 1]] [0, 1, 0, 1]] [10, 20]]
缩放结果(假设隐藏维度 ( d_k = 4 ),缩放因子为 ( \sqrt{4} = 2 )):
Scores = [[15 / 2, 14 / 2],
[10 / 2, 20 / 2]] = [[7.5, 7.0],
[5.0, 10.0]]
计算 ( Q_1 \times K_2^T ):
Q_1 × K_2^T = [[1, 2, 3, 4] × [[1, 1, 1, 1]^T = [[30, 60],
[4, 3, 2, 1]] [2, 2, 2, 2]] [20, 40]]
缩放结果:
Scores = [[30 / 2, 60 / 2],
[20 / 2, 40 / 2]] = [[15.0, 30.0],
[10.0, 20.0]]
步骤 3:数值稳定的 Softmax 计算
使用 最大值减法 技术对每个分块的 Scores 计算 Softmax,避免数值溢出。
对 ( Q_1 \times K_1^T ) 的 Scores 计算:
Scores = [[7.5, 7.0],
[5.0, 10.0]]
最大值减法:[[7.5 - 7.5, 7.0 - 7.5],
[5.0 - 10.0, 10.0 - 10.0]] = [[0, -0.5],
[-5, 0]]
指数计算:[[exp(0), exp(-0.5)],
[exp(-5), exp(0)]] = [[1.0, 0.6065],
[0.0067, 1.0]]
Softmax 归一化:[[1.0 / (1.0 + 0.6065), 0.6065 / (1.0 + 0.6065)],
[0.0067 / (0.0067 + 1.0), 1.0 / (0.0067 + 1.0)]] = [[0.622, 0.378],
[0.007, 0.993]]
步骤 4:加权输出计算
使用 Softmax 权重和 V 的块计算加权输出。
( Q_1 ) 的加权结果:
Output_1 = Softmax(Q_1 × K_1^T) × V_1 = [[0.622, 0.378] × [[1, 1, 1, 1],
[2, 2, 2, 2]] [2.622, 2.622, 2.622, 2.622]]
类似地,依次计算 ( Q_2 × K_2^T ) 的输出,逐块累积所有块的结果。
最终输出
将所有分块的输出累加到最终结果矩阵中,得到完整的注意力结果矩阵 Output。
Flash Attention 的优势
1. 显存优化:
• 普通 Attention 需要存储完整的注意力矩阵,显存占用为 ( O(n^2) )。
• Flash Attention 仅存储分块结果,显存占用为 ( O(n \cdot \text{block_size}) )。
2. 计算效率:
• 分块计算可以并行化,结合 CUDA 核心优化,速度显著提高。
3. 数值稳定性:
• 使用块级 Softmax 和最大值归一化,避免长序列点积的数值溢出问题。