快速理清 Attention 注意力和 Encoder, Decoder 概念

发布于:2025-03-09 ⋅ 阅读:(13) ⋅ 点赞:(0)

之前一直以为 Attention 和 RNN 没关系是凭空蹦出来的新概念;以为 Transformer, Encoder, Decoder 这几个概念是绑在一起的。并不尽然。

Encoder 和 Decoder

在这里插入图片描述
RNN 里就有 Encoder Decoder 的概念。其中,encoder 接受用户输入,写入 hidden state。Decoder 接受之前时刻的隐状态,并生成 logits。类似的架构也出现在 CNN 图像模型中。

所以,不论如何,只要是数据流长得像 encode, decode 的,都是 Encoder, Decoder

Attention 普遍意义上的注意力机制

请添加图片描述
上面 RNN 的问题是,decoder 只能拿到 encoder 最后的这个 <end> 位置的 feature,相当于必须串行接收整个输入,不能有注意力地选择输入序列的重点(不能加权)。

所以,我们想实现一个类似全连接的功能,在每个 decode 的位置,给输入序列的隐状态加个系数,共同喂给 decoder。所以,注意力其实就是把上面的这个序列算个系数。

但是怎么能让这个全连接矩阵可训练,可泛化是个问题。注意力机制引入了 Q, K, V 三个概念,其中 K, V 是 n 个 kv pair,Query 表示上图上面的部分,最后,Q 和 K 会两两一组算一个相关系数,然后用相关系数乘上 v,作为注意力输出。

其中,Q, K 表示。一个例子是我看涩图的注意力集中在人脸上,Q = 我; K = 涩图(V 和 K 严格绑定,是另一个空间对 K 的表示)Q,K 算一个相似度赋给 V.

一般 K = V。
请添加图片描述

请添加图片描述

自注意力机制

注意力是一个很宽泛的概念,不知道 QKV 是什么,自注意力机制则是规定了 QKV 同源,都是通过原始输入 X X X 乘上线性矩阵 W q , W k , W v W^q, W^k, W^v Wq,Wk,Wv 产生的。

请添加图片描述

给定输入矩阵 X X X(形状为 ( n , d ) (n, d) (n,d),其中 n n n 是序列长度, d d d 是嵌入维度),计算 Query(查询)、Key(键)、Value(值):
Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ,K=XWK,V=XWV
其中:

  • W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 是可训练的权重矩阵(形状均为 ( d , d k ) (d, d_k) (d,dk))。
  • Q , K , V Q, K, V Q,K,V 的形状均为 ( n , d k ) (n, d_k) (n,dk)

2. 计算注意力分数(Scaled Dot-Product Attention)

A = Q K T d k A = \frac{Q K^T}{\sqrt{d_k}} A=dk QKT
其中:

  • K T K^T KT 是 Key 矩阵的转置(形状为 ( d k , n ) (d_k, n) (dk,n)),使得 Q K T QK^T QKT 形状为 ( n , n ) (n, n) (n,n)
  • 1 d k \frac{1}{\sqrt{d_k}} dk 1 是缩放因子,防止大数值影响梯度。

3. 计算注意力权重(Softmax 归一化)

α = softmax ( A ) \alpha = \text{softmax}(A) α=softmax(A)
其中, α \alpha α 形状为 ( n , n ) (n, n) (n,n),表示序列中每个位置对其他位置的注意力权重。

4. 计算加权 Value

Z = α V Z = \alpha V Z=αV
其中:

  • Z Z Z 形状为 ( n , d k ) (n, d_k) (n,dk),即每个输入位置的加权输出。

5. 多头注意力(Multi-Head Attention)

如果使用 h h h 个头,每个头分别计算:
Z i = Attention ( X W Q i , X W K i , X W V i ) Z_i = \text{Attention}(X W_{Q_i}, X W_{K_i}, X W_{V_i}) Zi=Attention(XWQi,XWKi,XWVi)
然后将多个头的结果拼接并映射回原始维度:
Z = [ Z 1 , Z 2 , … , Z h ] W O Z = [Z_1, Z_2, \dots, Z_h] W_O Z=[Z1,Z2,,Zh]WO
其中:

  • W O W_O WO 是输出投影矩阵(形状为 ( h ⋅ d k , d ) (h \cdot d_k, d) (hdk,d))。
  • Z Z Z 形状回到 ( n , d ) (n, d) (n,d)

Ref

https://zhuanlan.zhihu.com/p/109585084
https://www.cnblogs.com/nickchen121/p/16470710.html

https://www.cnblogs.com/nickchen121/p/16470711.html


网站公告

今日签到

点亮在社区的每一天
去签到