MTA 论文
在 Transformer 中计算注意力权重时,仅依赖单个 Q 和 K 的相似度,无法有效捕捉多标记组合信息。(对于 A、B 两个词,单标记注意力需要分别计算两个词的注意力分数,再通过后处理定位共同出现的位置或通过多层隐式堆叠,增加模型深度和容量)。MTA 显示建模多标记依赖,同时不牺牲全局交互和额外参数。(通过卷积运算让他能够看到邻近的Q、K 以及其他注意力头的信息)
在 Transformer 其他部分,如 FFN 的输入/输出加卷积,主要是为了捕捉词元表示之间的局部依赖关系,不直接改变注意力机制本身如何计算相关性。
MTA 的卷积直接作用在 Q K T / A QK^T/A QKT/A,意味着卷积直接参与了决定哪些上下文位置应该被关注的过程,在处理词元间的关系强度。
提出两种方式:pre-softmax convolution 和 post-softmax convolution,MTA 默认采用 Pre-softmax Q-K Convolution 和 Post-softmax Head Mixing Convolution。二者区别在于是在 softmax 之前还是之后进行。
Q-K convolution
a i j = S o f t m a x ( ∑ i ′ = 0 c q − 1 ∑ j ′ = − ⌊ c k / 2 ⌋ ⌈ c k / 2 ⌉ − 1 1 i ≥ j − j ′ θ i ′ , j ′ q i − i ′ k j − j ′ ⊤ / d ) ( 1 ) a_{ij}=\mathrm{Softmax}\left(\sum_{i^{\prime}=0}^{c_{q}-1}\sum_{j^{\prime}=-\lfloor c_{k}/2\rfloor}^{\lceil c_{k}/2\rceil-1}\mathbf{1}_{i\geq j- j^{\prime}}\theta_{i^{\prime},j^{\prime}}q_{i-i^{\prime}}k_{j-j^{\prime}}^{\top}/\sqrt{d}\right) \qquad \qquad(1) aij=Softmax
i′=0∑cq−1j′=−⌊ck/2⌋∑⌈ck/2⌉−11i≥j−j′θi′,j′qi−i′kj−j′⊤/d
(1)
在卷积中,为防止未来信息泄露,需要做 Masking。理想的 Masking 比较复杂(见式(1)),采用一种简化形式:用 0 Mask 掉未来的 Q K T QK^T QKT 值,做卷积,再用 − ∞ -\infty −∞ Mask 掉结果中非法位置,再做 Softmax。
A = S o f t m a x ( M a s k − ∞ ( C o n v 2 d θ ( M a s k 0 ( A ^ ) ) ) ) . A=\mathrm{Softmax}\left(\mathrm{Mask}_{-\infty}\left(\mathrm{Conv}2\mathrm{d}_\theta\left(\mathrm{Mask}_0(\hat{A})\right)\right)\right). A=Softmax(Mask−∞(Conv2dθ(Mask0(A^)))).
Head Mixing Convolution
允许不同注意力头之间共享信息,放大重要信号。将 M 个头分成 M / c h M/c_h M/ch 个组,每组 c h c_h ch 个头。在每组的头内左 1D 卷积。同样可以在 softmax 之前或之后进行。
Group Normalization with depth scaling
改善梯度流,对抗深层网络中残差连接可能带来的主导效应(让模型更关注注意力部分输出,而不是仅仅传递上一层信息)。
在每个头的输出上独立应用组归一化,并结合一个随层数变化的缩放因子。
核心矛盾:在「增强注意力精度」和「保持计算效率」之间尚未找到完美平衡,当前更适合对计算资源不敏感的高精度场景。
实验结果
1.找字母块任务,验证 MTA 能够解决 [多条件匹配] 问题。
MTA 错误率接近 0%,而 Transformer 失败率超 50%。
2.LLM,在 105B 词元数据上训练 880M 参数模型
- MTA 仅在 1/4 的层 使用 Key-Query 卷积(核大小: c q = 6 , c k = 11 c_q=6,c_k=11 cq=6,ck=11)。
- 所有层使用 Head 卷积(核大小 c h = 2 c_h=2 ch=2)。