Native Sparse Attention: Hardware-Aligned and NativelyTrainable Sparse Attention
原生稀疏注意力:与硬件对齐且可原生训练的稀疏注意力
原文地址:官方论文地址
摘要
长上下文建模对于下一代语言模型至关重要,但标准注意力机制的高计算成本带来了显著的计算挑战。稀疏注意力在提高效率的同时保持模型能力提供了一种有前景的方向。我们提出了 NSA(原生可训练稀疏注意力,Natively trainable Sparse Attention),该机制结合了算法创新与硬件对齐优化,实现了高效的长上下文建模。NSA 采用 动态分层稀疏策略,结合 粗粒度的 token 压缩 和 细粒度的 token 选择,在保持全局上下文感知的同时确保局部精确性。
我们的方法在稀疏注意力设计方面引入了两个关键创新:
(1) 通过算术强度均衡的算法设计实现显著的加速,并针对现代硬件进行了优化实现。
(2) 支持端到端训练,降低预训练计算成本,而不会牺牲模型性能。
如 图 1 所示,实验表明,使用 NSA 预训练的模型在通用基准测试、长上下文任务和基于指令的推理任务中,性能可与 全注意力(Full Attention)模型 相媲美,甚至超越。同时,在 64k 长度序列上的解码、前向传播和反向传播过程中,NSA 相比Full Attention实现了大幅加速,验证了其在整个模型生命周期中的高效性。
图 1 | 全注意力模型与 NSA 在性能和效率上的对比
左图:尽管 NSA 采用稀疏注意力,但在通用基准测试、长上下文任务和推理评估中,其平均性能超过了完整注意力基线模型。
右图:在 64k 长度序列 处理任务中,NSA 在解码、前向传播和反向传播的所有阶段相较于全注意力实现了显著的计算加速。
1. 引言
研究社区越来越认识到 长上下文建模是下一代大语言模型的重要能力,这一需求由多种现实世界应用所驱动,包括深度推理(DeepSeek-AI, 2025; Zelikman et al., 2022)、仓库级代码生成(Zhang et al., 2023a; Zhang et al.)以及多轮自主智能体系统(Park et al., 2023)。近期的突破性进展,如 OpenAI 的 o-series 模型、DeepSeek-R1(DeepSeek-AI, 2025)以及 Gemini 1.5 Pro(Google et al., 2024),使得模型能够处理完整的代码库和长篇文档,在数千个 token 级别保持连贯的多轮对话,并在长距离依赖关系中执行复杂推理。然而,随着序列长度的增加,标准注意力机制(vanilla Attention,Vaswani et al., 2017)的高计算复杂度(Zaheer et al., 2020)成为关键的延迟瓶颈。理论估算表明,在 64k 长度的上下文解码过程中,基于 Softmax 的注意力计算占总延迟的 70–80%,这凸显了开发更高效注意力机制的迫切需求。
一种自然的高效长上下文建模方法是利用 Softmax 注意力的固有稀疏性(Ge et al., 2023; Jiang et al., 2023),即仅选择性地计算关键的 Query-Key 对,从而在保持性能的同时显著降低计算开销。近年来,多个研究方向展示了这一思路的潜力,包括:
- KV 缓存淘汰方法(KV-cache eviction)(Li et al., 2024; Zhang et al., 2023b; Zhou et al., 2024)
- 块级 KV 缓存选择方法(blockwise KV-cache selection)(Tang et al., 2024; Xiao et al., 2024)
- 基于采样、聚类或哈希的选择方法(sampling, clustering, or hashing-based selection)(Chen et al., 2024; Desai et al., 2024; Liu et al., 2024)
尽管这些策略表现出良好的潜力,但现有的稀疏注意力方法在实际部署中仍存在不足:
- 许多方法的实际加速效果远低于理论预期,无法在现代硬件上充分发挥其计算优势。
- 大多数方法仅关注推理阶段,缺乏有效的训练时支持,无法充分利用注意力的稀疏模式来优化整体训练效率。
为了解决这些局限性,部署高效的稀疏注意力必须应对两个关键挑战:
(1) 与硬件对齐的推理加速:将理论上的计算减少转化为实际的速度提升,需要在预填充(prefilling)和解码(decoding)阶段进行硬件友好的算法设计,以缓解内存访问和硬件调度瓶颈。
(2) 面向训练的算法设计:支持端到端计算,使用可训练算子来降低训练成本,同时保持模型性能。
这些需求对于实现高效的长上下文推理或训练至关重要。然而,在同时考虑这两个方面时,现有方法仍然存在明显的差距。
为了实现更高效、更有效的稀疏注意力机制,我们提出了NSA(Natively trainable Sparse Attention),一种原生可训练的稀疏注意力架构,结合了分层的token建模。如图2所示,NSA 通过将键(keys)和值(values)组织成时间块(temporal blocks),并通过三条注意力路径进行处理,从而减少每个查询(query)的计算量:压缩的粗粒度token、选择性保留的细粒度token,以及用于局部上下文信息的滑动窗口(sliding windows)。随后,我们实现了专门的计算核(kernels)以最大化其实用效率。NSA 针对上述关键需求引入了两项核心创新:
(1) 与硬件对齐的系统:优化基于块的稀疏注意力,以充分利用Tensor Core并优化内存访问,确保平衡的算术强度(arithmetic intensity)。
(2) 面向训练的设计:通过高效算法和反向传播算子,实现稳定的端到端训练。
这种优化使NSA既能支持高效部署,又能实现端到端训练。
我们通过在真实世界的语言语料库上进行全面实验来评估 NSA。在一个拥有 270 亿参数的 Transformer 主干模型上进行预训练,使用 2600 亿个 token,我们评估 NSA 在通用语言任务、长上下文任务以及思维链(chain-of-thought)推理任务中的表现。此外,我们在 A100 GPU 上使用优化的 Triton(Tillet et al., 2019)实现,对比 NSA 的计算核速度。实验结果表明,NSA 的性能与全注意力(Full Attention)基线相当或更优,同时优于现有的稀疏注意力方法。此外,与全注意力相比,NSA 在解码、前向传播和反向传播阶段均实现了显著的加速,且随着序列长度的增加,加速比进一步提高。这些结果验证了我们的分层稀疏注意力设计在模型能力和计算效率之间实现了有效的平衡。
图 2 | NSA 架构概览。
左图:该框架通过三个并行注意力分支处理输入序列。对于给定的查询,先前的键和值被处理为压缩注意力(用于粗粒度模式)、选择性注意力(用于重要的 token 块)以及滑动注意力(用于局部上下文)。
右图:不同注意力分支生成的注意力模式可视化。绿色区域表示需要计算注意力分数的区域,而白色区域表示可以跳过的区域。
2. 重新思考稀疏注意力方法
现代稀疏注意力方法在降低 Transformer 模型的理论计算复杂度方面取得了显著进展。然而, 大多数方法主要在推理阶段应用稀疏性,同时仍然保留预训练的全注意力(Full Attention)主干网络,这可能引入架构偏差,从而限制其充分利用稀疏注意力优势的能力。在介绍我们的原生稀疏架构之前,我们从两个关键角度系统性地分析这些方法的局限性。
2.1. 高效推理的幻觉
尽管许多方法在注意力计算中实现了稀疏性,但它们往往未能在推理延迟上获得相应的降低,主要受到以下两个挑战的影响:
阶段受限的稀疏性
诸如 H2O(Zhang et al., 2023b)的方法在自回归解码(autoregressive decoding)阶段应用稀疏性,但在预填充(prefilling)阶段需要进行计算密集型的预处理(如注意力映射计算、索引构建)。相比之下,MInference(Jiang et al., 2024)等方法仅专注于预填充阶段的稀疏性。这些方法无法在所有推理阶段实现加速,因为至少有一个阶段的计算成本仍然接近全注意力(Full Attention)。这种阶段特化(phase specialization)限制了这些方法在以预填充为主的任务(如书籍摘要、代码补全)或以解码为主的任务(如长链式推理(Wei et al., 2022))中的加速能力。
与先进注意力架构的不兼容性
某些稀疏注意力方法难以适应现代高效解码架构,例如 多查询注意力(Multiple-Query Attention, MQA)(Shazeer, 2019)和 分组查询注意力(Grouped-Query Attention, GQA)(Ainslie et al., 2023)。这些架构通过在多个查询头(query heads)之间共享 KV 记忆(key-value cache),显著减少了解码阶段的内存访问瓶颈。例如,在 Quest(Tang et al., 2024)等方法中,每个注意力头独立选择其 KV 缓存子集(KV-cache subset)。尽管这种方法在多头注意力(Multi-Head Attention, MHA)模型中能保持计算稀疏性和内存访问稀疏性,但在基于 GQA 之类的架构中,KV 缓存的内存访问量取决于 同一 GQA 组内所有查询头的选择集合的并集。这种架构特性意味着,尽管这些方法可以减少计算操作,但所需的 KV 缓存内存访问仍然较高。这一局限性导致一个关键的选择:尽管某些稀疏注意力方法减少了计算量,但其 分散的内存访问模式 与现代架构优化的高效内存访问设计相冲突。
这些限制的根源在于,许多现有的稀疏注意力方法主要关注 KV 缓存缩减 或 理论计算量降低,但在先进的计算框架或后端(backend)中难以实现显著的延迟优化。因此,我们的目标是开发一种算法,结合先进的架构设计和高效的硬件实现,以充分利用稀疏性来提升模型效率。
2.2. 可训练稀疏性之谜团
我们的原生可训练稀疏注意力(NSA)旨在解决从推理优先的方法中得出的两个关键问题:
性能下降。 在预训练后强行施加稀疏性会迫使模型偏离其优化轨迹。例如,Chen et al. (2024) 研究表明,前 20% 的注意力仅覆盖 70% 的总注意力分数,这使得预训练模型中的检索头等结构在推理时容易被剪枝。
训练效率需求。 现代大规模语言模型的训练需要高效处理长序列任务,包括更长文档的预训练,以提升模型能力,以及后续的长上下文微调和强化学习阶段。然而,现有稀疏注意力方法主要针对推理优化,而训练阶段的计算开销问题仍未解决,阻碍了更高效的长上下文模型发展。
此外,尝试将现有稀疏注意力方法用于训练会暴露出以下挑战:
非可训练组件。 诸如 ClusterKV (Liu et al., 2024) 的 k-means 聚类和 MagicPIG (Chen et al., 2024) 的 SimHash 选择等方法引入了离散操作,使计算图出现断裂,阻碍梯度在 token 选择过程中传播,限制了模型学习最优稀疏模式的能力。
低效的反向传播。 某些可训练的稀疏注意力方法在实践中仍然效率低下。例如,HashAttention (Desai et al., 2024) 采用基于 token 的选择策略,在注意力计算过程中需要加载大量非连续的 KV token,这种非连续的内存访问模式阻碍了 FlashAttention 等高效注意力技术的适配,而这些技术依赖于连续内存访问和块级计算以实现高吞吐量。因此,实际实现中往往不得不回退到低效的硬件利用,显著降低训练效率。
2.3. 原生稀疏性的必要性
推理效率和训练可行性方面的局限性促使我们从根本上重新设计稀疏注意力机制。我们提出 NSA(Natively Sparse Attention),一个原生稀疏注意力框架,旨在同时满足计算效率和训练需求。在接下来的章节中,我们将详细介绍 NSA 的算法设计和算子实现。
3. 方法论
我们的技术方法涵盖算法设计和内核优化。在接下来的小节中,我们首先介绍方法论的背景。然后,我们呈现 NSA 的整体框架,并介绍其关键算法组件。最后,我们详细阐述经过硬件优化的内核设计,以最大化其实用效率。
3.1. 背景
注意力机制 广泛应用于语言建模,其中每个查询 token q t q_t qt 计算与所有先前键 k : t k_{:t} k:t的相关性分数,以生成值 v : t v_{:t} v:t的加权和。形式上,对于长度为 t t t的输入序列,注意力操作定义如下:
o t = Attn ( q t , k : t , v : t ) o_t = \text{Attn} \left( q_t, k_{:t}, v_{:t} \right) ot=Attn(qt,k:t,v:t)
其中,Attn 表示注意力函数:
Attn ( q t , k : t , v : t ) = ∑ i = 1 t α t , i v i ∑ j = 1 t α t , j , α t , i = e q t T k i d k . \text{Attn} \left( q_t, k_{:t}, v_{:t} \right) = \sum_{i=1}^{t} \frac{\alpha_{t,i} v_i}{\sum_{j=1}^{t} \alpha_{t,j}}, \quad \alpha_{t,i} = e^{\frac{q_t^T k_i}{\sqrt{d_k}}}. Attn(qt,k:t,v:t)=i=1∑t∑j=1tαt,jαt,ivi,αt,i=edkqtTki.
在这里, α t , i \alpha_{t,i} αt,i 表示 q t q_t qt 与 k i k_i ki 之间的注意力权重, d k d_k dk 是键(key)的特征维度。 随着序列长度的增加,注意力计算在整体计算成本中占据越来越大的比重,对长上下文处理提出了重大挑战。
算术强度(Arithmetic Intensity)是计算操作与内存访问的比率,它本质上决定了算法在硬件上的优化方式。 每块 GPU 都有一个关键的算术强度,该值由其峰值计算能力与内存带宽的比率计算得出。对于计算任务来说,当算术强度高于该临界值时,计算受限于 GPU 的浮点运算能力(FLOPS),即计算受限(compute-bound);而当算术强度低于该临界值时,计算受限于内存带宽(memory-bound)。
具体而言,在因果自注意力(causal self-attention)机制中,训练和预填充(prefilling)阶段的批量矩阵乘法和注意力计算具有较高的算术强度,使得这些阶段在现代加速器上属于计算受限(compute-bound)。相反,自回归解码(auto-regressive decoding)过程中,由于每次前向传播仅生成一个 token,但需要加载整个键值缓存(key-value cache),导致算术强度较低,使其受内存带宽限制(memory-bound)。这导致了不同的优化目标——在训练和预填充阶段需要降低计算成本,而在解码阶段需要减少内存访问。
3.2. 总体框架
为了利用具有天然稀疏模式(natural sparse pattern)的注意力机制潜力, 我们提出用更加紧凑且信息密度更高的表示键值对(representation key-value pairs) K ~ t , V ~ t \tilde{K}_t, \tilde{V}_t K~t,V~t替换方程 (1) 中的原始键值对 k : t , v : t k_{:t}, v_{:t} k:t,v:t。
具体而言,我们正式定义优化后的注意力输出如下:
K ~ t = f K ( q t , k : t , v : t ) , V ~ t = f V ( q t , k : t , v : t ) \tilde{K}_t = f_K (q_t, k_{:t}, v_{:t}), \quad \tilde{V}_t = f_V (q_t, k_{:t}, v_{:t}) K~t=fK(qt,k:t,v:t),V~t=fV(qt,k:t,v:t)
o t ∗ = Attn ( q t , K ~ t , V ~ t ) o^*_t = \text{Attn} (q_t,\tilde{K}_t, \tilde{V}_t) ot∗=Attn(qt,K~t,V~t)
其中, K ~ t \tilde{K}_t K~t、 V ~ t \tilde{V}_t V~t 是基于当前查询 q t q_t qt和上下文记忆 k : t , v : t k_{:t}, v_{:t} k:t,v:t 动态构造的。 我们可以设计不同的映射策略,以获取不同类别的 K ~ t c \tilde{K}_{t}^c K~tc、 V ~ t c \tilde{V}_{t}^c V~tc,并将它们组合如下:
o t ∗ = ∑ c ∈ C g t c ⋅ Attn ( q t , K ~ t c , V ~ t c ) . o^*_t = \sum_{c \in C} g_{t}^c \cdot \text{Attn}(q_t, \tilde{K}_{t}^c, \tilde{V}_{t}^c). ot∗=c∈C∑gtc⋅Attn(qt,K~tc,V~tc).
如图 2 所示,NSA 采用三种映射策略 C = { cmp , slc , win } C = \{\text{cmp}, \text{slc}, \text{win}\} C={cmp,slc,win},分别表示压缩(compression)、选择(selection)和滑动窗口(sliding window),用于键(keys)和值(values)。 g t c ∈ [ 0 , 1 ] g_{t}^c \in [0,1] gtc∈[0,1] 是对应策略 c c c 的门控得分,通过 MLP 和 sigmoid 激活从输入特征中获得。 令 N t N_t Nt 表示重新映射的键/值总数:
N t = ∑ c ∈ C size [ K ~ t c ] . N_t = \sum_{c \in C} \text{size}[\tilde{K}_{t}^c]. Nt=c∈C∑size[K~tc].
我们通过确保 N t ≪ t N_t \ll t Nt≪t 来保持较高的稀疏率。
3.3. 算法设计
在本小节中,我们介绍token重映射(remapping)策略 f K f_K fK 和 f V f_V fV 的设计,包括以下三种方法: Token Compression(令牌压缩)、 Token Selection(令牌选择) 、Sliding Window(滑动窗口)。
3.3.1. Token 压缩
通过将连续的键(keys)或值(values)块聚合为块级(block-level)表示,我们得到压缩后的键和值,它们能够捕获整个块的信息。 形式化地,压缩后的键表示定义如下:
K ~ t cmp = f K cmp ( k : t ) = { ϕ ( k i d + 1 : i d + l ) ∣ 1 ≤ i ≤ ⌊ t − d l ⌋ } \tilde{K}_{t}^{\text{cmp}} = f_{K}^{\text{cmp}}(k_{:t}) =\begin{Bmatrix} {{\phi(k_{id+1:id+l}) \quad \mid1 \leq i \leq \left\lfloor \frac{t - d}{l} \right\rfloor}}\end{Bmatrix} K~tcmp=fKcmp(k:t)={ϕ(kid+1:id+l)∣1≤i≤⌊lt−d⌋}
其中:
- l l l是块长度(block length),
- d d d 是相邻块之间的滑动步长(stride between adjacent blocks),
- ϕ \phi ϕ 是一个可学习的 MLP,配备块内位置编码(intra-block position encoding),用于将块中的键映射为单个压缩键。
压缩键的张量形式为:
K ~ t cmp ∈ R d k × ⌊ ( t − d ) / l ⌋ \tilde{K}_{t}^{\text{cmp}} \in \mathbb{R}^{d_k \times \lfloor (t - d) / l \rfloor} K~tcmp∈Rdk×⌊(t−d)/l⌋
通常,我们采用 d < l d < l d<l 以减少信息碎片化(mitigating information fragmentation)。
类似地,压缩值(compressed value) V ~ t cmp \tilde{V}_{t}^{\text{cmp}} V~tcmp 也遵循相同的公式。
压缩表示能够捕获:
- 更粗粒度的高级语义信息(coarser-grained higher-level semantic information),
- 降低注意力计算的计算负担(reduce computational burden of attention)。
3.3.2. Token选择
仅使用压缩的键和值可能会丢失重要的细粒度信息,因此我们有必要选择性地保留个别键和值。下面我们描述一种高效的token选择机制,该机制能够以较低的计算开销识别并保留最相关的令牌。
块级选择。 我们的选择策略在空间上以连续的块处理键和值序列,其动机来源于两个关键因素:硬件效率考虑和注意力分数的固有分布模式。块级选择对于在现代 GPU 上实现高效计算至关重要。这是因为现代 GPU 架构在连续块访问时的吞吐量远高于基于随机索引的读取。此外,块级计算能够优化张量核心 (Tensor Cores) 的利用率。这种架构特性使得块级内存访问和计算成为高性能注意力机制实现的基本原则,例如 FlashAttention 采用的基于块的设计。块级选择遵循注意力分数的固有分布模式。 先前的研究(Jiang et al., 2024)表明,注意力分数通常表现出空间连续性,这意味着相邻的键往往具有相似的重要性水平。我们的可视化结果(见第 6.2 节)也显示了这种空间连续模式。
为了实现块级选择,我们首先将键和值序列划分为选择块。为了确定注意力计算中最重要的块,我们需要为每个块分配重要性分数。下面介绍我们用于计算这些块级重要性分数的方法。
重要性分数计算。 计算块的重要性分数可能会引入较大的计算开销。幸运的是,压缩token的注意力计算会生成中间注意力分数,我们可以利用这些分数来推导选择块的重要性分数,其计算公式如下:
p t c m p = Softmax ( q t T K ~ t c m p ) p_{\text{t}}^{cmp} = \text{Softmax} \left( q_t^T \tilde{K}_{\text{t}}^{cmp} \right) ptcmp=Softmax(qtTK~tcmp)
其中, p cmp t ∈ R ⌊ t − d ⌋ p_{\text{cmp}}^t \in \mathbb{R}^{\lfloor t - d \rfloor} pcmpt∈R⌊t−d⌋ 表示查询向量 q t q_t qt 与压缩键 K ~ cmp t \tilde{K}_{\text{cmp}}^t K~cmpt 之间的注意力分数。设 l ′ l' l′ 为选择块的大小。当压缩块与选择块采用相同的分块方案,即 l ′ = l = d l' = l = d l′=l=d时,我们可以直接得到选择块的重要性分数: p slc t = p cmp t p_{\text{slc}}^t = p_{\text{cmp}}^t pslct=pcmpt对于分块方案不同的情况,我们根据它们的空间关系来计算选择块的重要性分数。在满足 d ∣ l d \mid l d∣l 和 d ∣ l ′ d \mid l' d∣l′ 的前提下,公式如下:
其中, [ ⋅ ] [·] [⋅] 表示向量元素的索引操作。
对于采用 GQA 或 MQA 机制的模型,由于键-值缓存 (KV cache) 在多个查询头 (query heads) 之间共享,因此需要确保各查询头之间的一致性块选择,以最小化解码过程中的 KV 缓存加载。组内查询头共享的重要性分数定义如下:
p t s l c ′ = ∑ h = 1 H p t s l c , ( h ) p_{\text{t}}^{slc'} = \sum_{h=1}^{H} p_{\text{t}}^{{slc},(h)} ptslc′=h=1∑Hptslc,(h)
其中,上标 ( h ) (h) (h) 表示查询头索引, H H H 是每个组内的查询头数量。该聚合操作确保了同一组内的查询头采用一致的块选择策略。
Top-𝑛 块选择。 在获取选择块的重要性分数后,我们保留按块重要性分数排名前 𝑛 的稀疏块内的令牌,其计算公式如下:
I t = { i ∣ rank ( p t s l c ′ [ i ] ) ≤ n } I_t = \{ i \ | \ \text{rank}(p_{\text{t}}^{slc'}[i]) \leq n \} It={i ∣ rank(ptslc′[i])≤n}
K ~ t s l c = Cat ( { k i l ′ + 1 : ( i + 1 ) l ′ ∣ i ∈ I t } ) \tilde{K}_{\text{t}}^{slc} = \text{Cat} \left( \{ k_{i l' +1 : (i+1) l'} \ | \ i \in I_t \} \right) K~tslc=Cat({kil′+1:(i+1)l′ ∣ i∈It})
其中, rank ( ⋅ ) \text{rank}(\cdot) rank(⋅)表示按降序排列的位置, rank = 1 \text{rank} = 1 rank=1对应最高分数, I t I_t It 是所选块的索引集合, Cat \text{Cat} Cat 表示拼接操作。
K ~ t s l c ∈ R d k × n l ′ \tilde{K}_{\text{t}}^{slc} \in \mathbb{R}^{d_k \times n l'} K~tslc∈Rdk×nl′ 是由压缩键组成的张量。类似的公式同样适用于细粒度值 V ~ t s l c \tilde{V}_{\text{t}}^{slc} V~tslc。
所选键和值随后将与 q t q_t qt一起参与注意力计算,如公式 (5) 所定义。
3.3.3. 滑动窗口
在注意力机制中,局部模式通常能够更快地适应并主导学习过程,这可能会导致模型无法有效利用压缩和选择token进行学习。为了解决这一问题,我们引入了一个专门的滑动窗口分支,明确处理局部上下文,使其他分支(压缩和选择)能够专注于学习各自的特征,而不会受到局部模式的捷径影响。 具体而言,我们在窗口 w w w 内维护最近的token: K ~ t win = k t − w : t , V ~ t win = v t − w : t \tilde{K}^{\text{win}}_t = k_{t-w:t}, \quad \tilde{V}^{\text{win}}_t = v_{t-w:t} K~twin=kt−w:t,V~twin=vt−w:t并将不同信息源(压缩token、选择token、滑动窗口)的注意力计算隔离到独立的分支中。然后,这些分支的输出通过一个可学习的门控机制进行聚合。 为了进一步防止不同注意力分支之间的捷径学习,同时保持较低的计算开销,我们为三个分支提供了独立的键和值。这种架构设计能够稳定学习,防止局部和长程模式识别之间的梯度干扰,同时引入的计算开销极小。
在获取三类键和值( K ~ t cmp , V ~ t cmp \tilde{K}^{\text{cmp}}_t, \tilde{V}^{\text{cmp}}_t K~tcmp,V~tcmp; K ~ t slc , V ~ t slc \tilde{K}^{\text{slc}}_t, \tilde{V}^{\text{slc}}_t K~tslc,V~tslc; K ~ t win , V ~ t win \tilde{K}^{\text{win}}_t, \tilde{V}^{\text{win}}_t K~twin,V~twin)后,我们按照公式 (5) 计算最终的注意力输出。结合上述压缩、选择和滑动窗口机制,构成了 NSA 的完整算法框架。
3.4. 核心设计
为了在训练和预填充(prefilling)阶段实现与 FlashAttention 相当的加速效果,我们基于 Triton 实现了硬件优化的稀疏注意力核心(kernel)。由于多头注意力(MHA)在解码阶段的内存开销较大且效率较低,我们专注于使用共享 KV 缓存(KV cache)的架构,例如 GQA(分组查询注意力)和 MQA(多查询注意力),这也是当前最先进的大模型(LLM)所采用的方案。 在我们的设计中,压缩注意力 和 滑动窗口注意力 计算可以直接兼容现有的 FlashAttention-2 核心,但 稀疏选择注意力 需要特殊的核心设计。如果按照 FlashAttention 的策略,将时间上连续的查询块加载到 SRAM(静态随机存取存储器),那么由于查询块中的不同查询可能需要访问非连续的 KV 块,会导致低效的内存访问。 为了解决这一问题,我们采用了一种不同的查询分组策略: 对于查询序列中的每个位置,我们将 GQA 组内的所有查询头一起加载到 SRAM,因为它们共享相同的稀疏 KV 块。 图 3 展示了我们前向传播的具体实现流程。该核心架构具有以下关键特性:
1. 以组为中心的数据加载(Group-Centric Data Loading)
在每个内循环(inner loop)中,加载组内所有注意力头的查询向量 Q ∈ R [ h , d k ] Q \in \mathbb{R}^{[h, d_k]} Q∈R[h,dk](位于位置 t t t),以及它们共享的稀疏键/值(KV)块索引 I t I_t It。
2. 共享 KV 读取(Shared KV Fetching)
在内循环中,按序加载由索引 I t I_t It 指定的连续键/值(KV)块到 SRAM(静态随机存取存储器),以 K ∈ R [ B k , d k ] K \in \mathbb{R}^{[B_k, d_k]} K∈R[Bk,dk]和 V ∈ R [ B k , d v ] V \in \mathbb{R}^{[B_k, d_v]} V∈R[Bk,dv] 的形式存储,其中 B k B_k Bk 为核心块大小,并满足 B k ∣ l ′ B_k | l' Bk∣l′(即 B k B_k Bk 是 l ′ l' l′ 的因子)。这种方式能够最小化内存加载,提高访问效率。
3. 在网格(Grid)上的外循环(Outer Loop on Grid)
由于内循环的长度(与所选块的数量 ( n ) 成正比)在不同的查询块之间几乎保持一致,我们将查询和输出循环放入 Triton 的网格调度器(grid scheduler),以简化并优化核心(kernel)执行效率。
该设计实现了近乎最优的计算密度(arithmetic intensity),主要通过以下两点:
- 消除冗余的 KV 传输,通过组内共享减少内存带宽占用。
- 均衡 GPU 流式多处理器(Streaming Multiprocessors, SMs)之间的计算负载,提高并行计算效率。
图 3 | NSA 的核心(Kernel)设计。 该核心按照 GQA 组(网格循环,Grid Loop)加载查询向量(queries),获取对应的稀疏 KV 块(内循环,Inner Loop),并在 SRAM 上执行注意力计算。绿色块表示存储在 SRAM 上的数据,蓝色块表示存储在 HBM(高带宽存储器)上的数据。
4. 实验
我们从三个角度评估 NSA:(1)通用基准测试性能,(2)长上下文基准测试性能,以及(3)链式思维(Chain-of-Thought)推理能力,并将其与全注意力(Full Attention)基线模型以及当前最先进的稀疏注意力方法进行对比。关于我们稀疏计算范式的效率分析将在第 5 节进行详细讨论,包括训练和推理速度的分析。
4.1. 预训练设置
遵循当前最先进大模型(LLMs)的常见做,我们的实验采用结合了分组查询注意力(Grouped-Query Attention, GQA)和专家混合(Mixture-of-Experts, MoE)的主干网络,该模型总参数量为 27B(270 亿),其中活跃参数量为 3B(30 亿)。
模型包含 30 层,隐藏维度设为 2560。在 GQA 机制中,我们将查询分组数设为 4,总注意力头数设为 64。对于每个注意力头,查询(query)、键(key)和值(value)的隐藏维度分别设为 d q = d k = 192 d_q = d_k = 192 dq=dk=192,值向量维度设为 d v = 128 d_v = 128 dv=128。
在 MoE 机制中,我们采用 DeepSeekMoE(Dai et al., 2024;DeepSeek-AI, 2024)结构,其中包含 72 个路由专家(routed experts)和 2 个共享专家(shared experts),并将Top-K 选取的专家数设为 6。为了确保训练稳定性,我们在第一层将 MoE 替换为 MLP,并采用SwiGLU 形式的前馈网络(FFN)。
图 4 | 27B 参数模型中全注意力(Full Attention)与 NSA 的预训练损失对比。两种模型均表现出稳定的收敛性,但 NSA 取得了更低的损失值。
表 1 | 全注意力基线(Full Attention)与 NSA 在通用基准测试上的预训练性能对比,涵盖知识(MMLU、MMLU-PRO、CMMLU)、推理(BBH、GSM8K、MATH、DROP)和代码(MBPP、HumanEval)任务。尽管具有较高的稀疏性,NSA 在大多数基准测试上仍取得了更优的平均性能。
所提出的架构在计算成本与模型性能之间实现了有效的权衡。对于 NSA,我们设置压缩块大小为 l = 32 l = 32 l=32,滑动步长为 d = 16 d = 16 d=16,选定块大小为 l ′ = 64 l' = 64 l′=64,选定块数量为 n = 16 n = 16 n=16(包括固定激活的 1 个初始块和 2 个局部块),以及滑动窗口大小为 w = 512 w = 512 w=512。 全注意力(Full Attention)模型和稀疏注意力(NSA)模型均在 8k 长度文本的 270B 令牌上进行预训练,随后采用 YaRN(Peng et al., 2024)在 32k 长度文本上继续训练并进行有监督微调,以实现长上下文适配。所有模型均训练至完全收敛,以确保公平比较。 如图 4 所示,NSA 的预训练损失曲线与全注意力基线相比表现出稳定且平滑的下降趋势,并且 NSA 始终优于全注意力模型。
4.2. 基线方法
除了与全注意力(Full Attention)进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法,包括 H2O(Zhang et al., 2023b)、infLLM(Xiao et al., 2024)、Quest(Tang et al., 2024)以及 Exact-Top。Exact-Top 方法首先计算完整的注意力分数,然后为每个查询选择前 n n n 个得分最高的键,并在这些位置上计算注意力。 这些方法涵盖了不同的稀疏注意力范式,包括 KV 缓存淘汰(KV-cache eviction)、查询感知选择(query-aware selection)和精确的前 n n n 稀疏选择(exact top- n n n sparse selection)。
在通用评测中,大多数样本的长度均在稀疏注意力基线的局部上下文窗口范围内,因此这些方法在此场景下与全注意力方法基本等效。因此,我们仅在该设定下展示 NSA 与全注意力基线的比较结果。 在长上下文评测中,我们对所有基线方法进行比较,并确保所有稀疏注意力方法的稀疏度保持一致,以保证公平对比。对于链式思维推理(chain-of-thought reasoning)评测,该任务需要基于长文本进行有监督微调,而稀疏注意力基线方法不支持训练,因此我们仅与全注意力方法进行比较。
表 2 | 我们的 NSA 与基线方法在 LongBench 上的性能对比,包括单文档问答(QA)、多文档问答(QA)、合成任务和代码任务等子集。NSA 在大多数基线方法(包括全注意力)上均表现优越。
4.3. 性能对比
通用评估
我们在涵盖知识、推理和编程能力的综合基准测试套件上评估了预训练的 NSA 和全注意力(Full Attention)基线模型,包括 MMLU(Hendrycks et al., 2020)、MMLU-PRO(Wang et al., 2024)、CMMLU(Li et al., 2023)、BBH(Suzgun et al., 2022)、GSM8K(Cobbe et al., 2021)、MATH(Hendrycks et al., 2020)、DROP(Dua et al., 2019)、MBPP(Austin et al., 2021)和 HumanEval(Chen et al., 2021)。结果如表 1 所示。尽管 NSA 采用稀疏注意力机制,但其整体性能仍然优于所有基线方法,在 9 项指标中的 7 项上超越了全注意力模型。这表明尽管 NSA 在较短序列上未能完全发挥其计算效率优势,但其整体性能依然强劲。值得注意的是,NSA 在推理相关基准测试中表现出显著提升(DROP: +0.042,GSM8K: +0.034),表明我们的预训练方法有助于模型学习专门的注意力机制。这种稀疏注意力预训练机制迫使模型关注最重要的信息,可能通过过滤掉无关的注意力路径噪声来提升性能。此外,NSA 在各种评估任务中的一致表现也验证了其作为通用架构的稳健性。
长上下文评估
如图 5 所示,NSA 在 64k 上下文 “needle-in-a-haystack”(Kamradt, 2023)测试中,在所有位置上均实现了完美的检索准确率。这一性能得益于 NSA 的分层稀疏注意力设计,该设计结合了用于高效全局上下文扫描的压缩 token 和用于精确局部信息检索的选择 token。粗粒度压缩机制以低计算成本识别相关上下文块,而在选定 token 上进行的 token 级别注意力则确保了关键细粒度信息的保留。此设计使 NSA 既能保持全局感知能力,又能实现局部精确性。
我们还在 LongBench(Bai et al., 2023)上评估了 NSA,并与最先进的稀疏注意力方法及全注意力基线进行对比。为确保稀疏性一致,我们将所有稀疏注意力基线中每个查询激活的 token 数设为 2560,与 NSA 处理 32k 序列时的平均激活 token 数相对应。参考 Stream-LLM(Xiao et al., 2023),此 token 预算包括前 128 个引导 token 和 512 个局部 token。由于某些 LongBench 子集在所有模型上的得分较低,无法提供有意义的比较,我们将其排除。如表 2 所示,NSA 取得了最高的平均得分 0.469,超越所有基线(比全注意力高 +0.032,比 Exact-Top 高 +0.046)。这种提升源于两大关键创新:(1)原生的稀疏注意力设计,使得稀疏模式在预训练过程中能够进行端到端优化,实现稀疏注意力模块与其他模型组件的同步适配;(2)分层稀疏注意力机制在局部和全局信息处理之间取得了平衡。
值得注意的是,NSA 在需要对长上下文进行复杂推理的任务上表现尤为突出,在多跳问答(HPQ 和 2Wiki)任务中比全注意力提高了 +0.087 和 +0.051,在代码理解(LCC)任务中提高了 +0.069,在段落检索(PassR-en)任务中提高了 +0.075。这些结果验证了 NSA 处理各类长上下文任务的能力,其原生的稀疏注意力预训练提供了额外的学习任务最优模式的能力。
图 5 | 64k 上下文长度下不同位置的 Needle-in-a-Haystack 检索准确率。 NSA 通过其分层稀疏注意力设计实现了完美的检索准确率。
链式思维推理评估
为了评估 NSA 在先进下游训练范式中的兼容性,我们研究了其通过后训练学习链式思维(Chain-of-Thought)数学推理能力的能力。鉴于强化学习在小规模模型上的有效性有限,我们采用 DeepSeek-R1 进行知识蒸馏,并对 10B 令牌的 32k 长度数学推理轨迹进行监督微调(SFT)。由此训练出两种可比模型:Full Attention-R(全注意力基线)和 NSA-R(我们的稀疏变体)。我们在具有挑战性的美国邀请数学竞赛(AIME 24)基准测试上评估了这两种模型。在推理时,我们采用 0.7 采样温度和 0.95 的 top-p 取样,为每个问题生成 16 个回答,并计算平均得分。
为了验证推理深度的影响,我们进行了两个不同上下文限制(8k 和 16k token)的实验,以衡量延长推理链是否能提高准确率。模型预测的示例对比见附录 A。
表 3 | AIME 基于指令的评估结果(经过监督微调后)。 我们的 NSA-R 在 8k 和 16k 序列长度下均优于 Full Attention-R。
图 6 | Triton 版 NSA 内核与 Triton 版 FlashAttention-2 内核的对比。我们的实现显著降低了所有上下文长度下的延迟,且随着输入长度的增加,提升更加明显。
如表 3 所示,在 8k 上下文设置下,NSA-R 的准确率显著高于 Full Attention-R(+0.075),这一优势在 16k 上下文设置下仍然保持(+0.054)。
这些结果验证了原生稀疏注意力的两个关键优势:
(1) 预训练的稀疏注意力模式能够高效捕捉远程逻辑依赖关系,这对于复杂的数学推导至关重要;
(2) 我们的架构在硬件对齐的设计下保持了足够的上下文密度,支持更深层次的推理而不会导致灾难性遗忘。
在不同上下文长度下的持续优越表现,进一步证实了当稀疏注意力原生集成到训练流程中时,其在高级推理任务中的可行性。
5. 效率分析
我们在 8-GPU A100 系统上评估 NSA 相对于 Full Attention 的计算效率。在效率分析中,我们还将模型配置为 GQA 组数 g = 4 g = 4 g=4,每组头数 h = 16 h = 16 h=16,查询/键维度 d k = 192 d_k = 192 dk=192,值维度 d v = 128 d_v = 128 dv=128。遵循第 4 节中的相同设置,我们设定 NSA 的压缩块大小 l = 32 l = 32 l=32,滑动步长 d = 16 d = 16 d=16,选定块大小 l ′ = 64 l' = 64 l′=64,选定块数 n = 16 n = 16 n=16,以及滑动窗口大小 w = 512 w = 512 w=512。
5.1. 训练速度
我们比较了基于 Triton 实现的 NSA 注意力机制和 Full Attention,与基于 Triton 的 FlashAttention-2 进行对比,以确保在相同的后端上进行公平的速度比较。如图 6 所示,随着上下文长度的增加,NSA 的加速效果逐步提升,在 64k 上下文长度下,前向传播速度最高可达 9.0×,反向传播速度可达 6.0×。值得注意的是,随着序列长度的增长,NSA 的速度优势变得更加显著。
这种加速主要源于我们面向硬件优化的稀疏注意力算法设计,以最大化计算效率:
(1) 基于块的内存访问模式 通过合并加载(coalesced loads)最大化 Tensor Core 的利用率;
(2) 精细化的内核循环调度 消除了冗余的 KV 传输,提高了计算效率。
表 4 | 解码过程中每次注意力操作的内存访问量(以等效 token 数表示)。由于解码的算术强度较低且受内存带宽限制,预计加速比大致与内存访问量呈线性关系。
5.2. 解码速度
注意力机制的解码速度主要受内存访问瓶颈的限制,这与 KV 缓存加载量密切相关。在每个解码步骤中,我们的 NSA 仅需加载最多
⌊(𝑠−𝑙) / 𝑑⌋ 个压缩 token” token、𝑛𝑙′ 个选定 token 以及 𝑤 个邻近 token,其中 𝑠 为缓存序列长度。正如表 4 所示,随着解码长度的增加,我们的方法显著减少了延迟,在 64k 上下文长度下实现了高达 11.6× 的加速比。这种内存访问效率的提升在更长的序列上更加明显。
6. 讨论
在本节中,我们回顾 NSA 的开发过程,并讨论在探索不同稀疏注意力策略时获得的关键见解。尽管我们的方法展示了有前景的结果,但理解替代策略所遇到的挑战以及分析注意力模式能够为未来的研究方向提供重要的背景信息。我们首先探讨促使我们做出当前设计选择的替代 token 选择策略所面临的挑战,然后通过可视化分析注意力分布模式。
6.1. 替代 token 选择策略的挑战
在设计 NSA 之前,我们尝试将现有的稀疏注意力方法应用于训练阶段。然而,这些尝试遇到了诸多挑战,促使我们设计了一种新的稀疏注意力架构:
基于 Key 聚类的策略
我们研究了类似 ClusterKV(Liu 等, 2024)等基于聚类的方法,这些方法将相同聚类的 Key 和 Value 存储在连续的内存区域中。尽管这些方法在理论上可用于训练和推理,但它们面临三大挑战:
- 动态聚类机制引入了非平凡的计算开销;
- 由于聚类间的不均衡,算子优化难度增加,特别是在专家混合(MoE)系统中,倾斜的专家并行(EP)组执行时间导致持续的负载不均衡;
- 由于必须进行周期性重新聚类以及分块顺序训练协议的约束,导致实现上的限制。
这些因素共同造成了严重的瓶颈,显著限制了其在实际部署中的有效性。
图 7 | 在一个具有 30 亿参数的模型上,对比 Full Attention 和不同的 Token 选择策略的训练损失。我们的 NSA 取得了更好的性能。
图 8 | 全注意力 Transformer 的注意力图可视化。浅色区域表示较高的注意力值。 如图所示,注意力分数呈现块状聚集分布。
其他块状选择策略。我们还考虑了与 NSA 不同的块状键值选择策略,例如 Quest (Tang et al., 2024) 和 InfLLM (Xiao et al., 2024)。这些方法依赖于为每个块计算一个重要性分数,并根据其与查询 𝑞𝑡 的相似度选择前 𝑛 个块。然而,现有方法面临两个关键问题:
(1) 由于选择操作是不可微的,基于神经网络的重要性分数计算依赖于辅助损失,这会增加算子开销,并且往往降低模型性能;
(2) 基于启发式、无参数的计算策略存在较低的召回率,导致次优的性能表现。
我们在具有类似架构的 3B 参数模型上评估了这两种方法,并将其损失曲线与 NSA 和 Full Attention 进行了比较。对于基于辅助损失的选择方法,我们为每个块引入了额外的查询和代表性键,以估计块的重要性分数。这些分数由原始查询与块内键之间的平均注意力分数进行监督。对于基于启发式、无参数的选择方法,我们遵循 Quest 的策略,使用查询与键块的逐坐标 min-max 乘积进行直接选择,而不引入额外参数。
此外,我们还探索了一种冷启动训练方法,即在训练的前 1000 步使用 Full Attention,然后再切换到启发式块状选择策略。如图 7 所示,这两种方法的损失均较差。
6.2. 可视化
为了探索 Transformer 注意力分布中的潜在模式,并为我们的设计寻找灵感,我们在图 8 中可视化了预训练的 27B Full Attention 模型的注意力图。该可视化揭示了一个有趣的现象,即注意力分数往往表现出块状聚集特性,相邻的键通常具有相似的注意力分数。
这一观察结果启发了 NSA 的设计,表明基于空间连续性选择键块可能是一个有效的策略。块状聚集现象表明,序列中相邻的 token 可能与查询 token 共享某些语义关系,尽管这些关系的具体性质仍需进一步研究。这一发现促使我们探索一种基于连续 token 块的稀疏注意力机制,以提高计算效率,同时保留高注意力模式。
7. 相关工作
我们回顾了现有通过稀疏注意力提高注意力计算效率的方法。这些方法可以根据其核心策略大致分为三类:(1) 固定稀疏模式,(2) 动态 token 剪枝,(3) 查询感知选择。我们介绍了每个类别中的一些代表性工作。
7.1. 固定稀疏模式
SlidingWindow 是一种常用的方法,它仅允许查询在固定窗口内计算注意力。StreamingLLM (Xiao et al., 2023) 通过维护上下文的两个关键部分(早期 token 作为注意力汇聚区和局部上下文窗口)来解决长文本流处理的挑战。虽然这些方法有效降低了内存和计算成本,但它们刚性地忽略部分上下文,限制了在需要完整上下文理解的任务上的表现。
7.2. 动态Token剪枝
H2O (Zhang et al., 2023b) 采用了一种自适应方法,在解码过程中减少 KV-cache 的内存使用。该方法根据注意力分数动态淘汰对未来预测不太重要的 token。SnapKV (Li et al., 2024) 也引入了一种 token 剪枝策略,通过有选择地保留最关键的特征来减少 KV 缓存,从而实现高效的内存利用。SnapKV 通过注意力权重分析和投票机制在预填充过程中识别重要特征,并在更新 KV 缓存时,将选定的压缩特征与最新上下文结合,以保持提示一致性。
7.3. 查询感知选择
Quest (Tang et al., 2024) 采用了一种块状选择策略,通过计算查询与键块的坐标维度 min-max 乘积来估计每个块的重要性,并根据得分选择最重要的前 𝑛 个键值块进行注意力计算。InfLLM (Xiao et al., 2024) 结合了固定模式和检索机制,通过维护注意力汇聚区、本地上下文和可检索块来进行选择,该方法从每个块中选取代表性键来估计块的重要性。HashAttention (Desai et al., 2024) 将关键 token 识别建模为推荐问题,通过学习的函数将查询和键映射到哈希空间进行计算。ClusterKV (Liu et al., 2024) 通过首先对键进行聚类,然后基于查询-簇相似度选择最相关的簇进行注意力计算,以实现稀疏性。
8. 结论
我们提出了 NSA,一种面向硬件对齐的稀疏注意力架构,用于高效的长上下文建模。通过在可训练架构中集成分层 token 压缩与块状 token 选择,我们的架构在保持 Full Attention 性能的同时,实现了加速训练和推理。NSA 推动了当前技术的发展,表现在通用基准测试中表现与 Full Attention 基线相匹配,在长上下文评估中超越建模能力,并增强推理能力,同时伴随计算延迟的可测量降低,实现显著的加速。