一、前言:为什么QKV这么重要?
Transformer模型彻底改变了自然语言处理领域,而其中的核心——注意力机制(Attention)的QKV三要素,是理解Transformer的关键。很多初学者看到Q(Query)、K(Key)、V(Value)就一头雾水:它们到底是什么?从哪里来?为什么需要它们?
本文将用最直观的方式,带你彻底理解QKV的原理,配有详细流程图和可运行代码,保证小白也能轻松掌握!
二、QKV的本质:它们从哪里来?
2.1 QKV的来源
QKV都来自同一个地方——我们的输入数据! 只是通过不同的"通道"(线性变换)变成了三种不同形式。
具体过程:
- 输入序列首先被转换为词向量矩阵X
- X分别与三个不同的权重矩阵WQ、WK、WVW^Q、W^K、W^VWQ、WK、WV相乘
- 得到:Q=X⋅WQ,K=X⋅WK,V=X⋅WVQ = X·W^Q,K = X·W^K,V = X·W^VQ=X⋅WQ,K=X⋅WK,V=X⋅WV
这些权重矩阵WQ、WK、WVW^Q、W^K、W^VWQ、WK、WV是模型在训练过程中自动学习得到的参数,不是预先设定好的。
2.2 一个小比喻:图书馆找书
想象你去图书馆找书:
- Query(Q) 就是你想查找的"问题"(比如"人工智能入门")
- Key(K) 就是图书馆里每本书的"标题标签"
- Value(V) 就是每本书的"实际内容"
当你(Q)在图书馆里寻找时,系统会将你的查询与所有书的标题(K)进行匹配,找到最相关的书,然后把书的内容(V)呈现给你。
三、自注意力机制:让词自己关注自己
3.1 Self-Attention的核心特点
在Transformer的自注意力机制(Self-Attention) 中,Q、K、V都来自同一段文本。这意味着,当模型处理一个句子时,每个词都会:
- 作为"提问者"(Q)去询问其他词
- 作为"被询问对象"(K)被其他词提问
- 作为"回答内容"(V)提供给提问者参考
比如处理句子"我喜欢吃苹果"时:
- “我"会作为Q去关注"喜欢”、“吃”、“苹果”
- 同时"我"也会作为K和V被其他词关注
3.2 注意力计算的完整流程
计算步骤详解:
- 计算注意力分数:Q和K做点积,衡量它们之间的相关性
- 缩放:除以√d_k(d_k是K向量的维度),防止点积过大
- Softmax:将分数转换为概率分布(注意力权重)
- 加权求和:用这些权重对V进行加权求和,得到最终输出
公式表示为:Attention(Q,K,V)=Softmax(QKT/√dk)VAttention(Q, K, V) = Softmax(QKᵀ/√d_k)VAttention(Q,K,V)=Softmax(QKT/√dk)V
四、多头注意力:Transformer的真正创新
4.1 为什么需要多头?
单头注意力有个局限:它只能学习一种"关注模式"。就像你只能用一种方式看世界!
多头注意力(Multi-Head Attention) 让模型能够:
- 同时从不同子空间学习信息
- 捕捉更丰富的语言特征
- 提高模型的表达能力
4.2 多头注意力工作原理
关键步骤:
- 将Q、K、V分别拆分成h个"头"(例如8头或16头)
- 每个头独立计算注意力
- 将所有头的输出拼接起来
- 通过一个线性变换得到最终输出
五、位置编码:让模型知道词的顺序
5.1 为什么需要位置编码?
自注意力机制本身不包含位置信息——它把输入当作一个"词袋",不知道词的顺序!
解决方案:添加位置编码(Positional Encoding),让模型知道每个词的位置。
5.2 位置编码与QKV的关系
位置编码在输入阶段就加到词向量上,然后再生成QKV:
- 输入 = 词向量 + 位置编码
- Q = (词向量 + 位置编码)·WQW^QWQ
- K = (词向量 + 位置编码)·WKW^KWK
- V = (词向量 + 位置编码)·WVW^VWV
六、完整代码实现
6.1 单头自注意力实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
"""单头自注意力机制"""
def __init__(self, embed_size):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
# 三个可学习的权重矩阵
self.W_q = nn.Linear(embed_size, embed_size, bias=False)
self.W_k = nn.Linear(embed_size, embed_size, bias=False)
self.W_v = nn.Linear(embed_size, embed_size, bias=False)
def forward(self, x):
"""
x: 输入张量,形状为(batch_size, seq_length, embed_size)
"""
# 生成Q, K, V
Q = self.W_q(x) # [batch_size, seq_length, embed_size]
K = self.W_k(x) # [batch_size, seq_length, embed_size]
V = self.W_v(x) # [batch_size, seq_length, embed_size]
# 计算注意力分数: Q·K^T
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # [batch_size, seq_length, seq_length]
# 缩放: 除以√d_k
d_k = self.embed_size
attn_scores = attn_scores / (d_k ** 0.5)
# Softmax获取注意力权重
attn_weights = F.softmax(attn_scores, dim=-1)
# 加权求和: attn_weights·V
output = torch.matmul(attn_weights, V) # [batch_size, seq_length, embed_size]
return output, attn_weights
# 使用示例
batch_size = 2
seq_length = 4
embed_size = 8
x = torch.randn(batch_size, seq_length, embed_size) # 随机输入
attention = SelfAttention(embed_size)
output, attn_weights = attention(x)
print("输入形状:", x.shape) # torch.Size([2, 4, 8])
print("输出形状:", output.shape) # torch.Size([2, 4, 8])
print("注意力权重形状:", attn_weights.shape) # torch.Size([2, 4, 4])
输出结果:
输入形状: torch.Size([2, 4, 8])
输出形状: torch.Size([2, 4, 8])
注意力权重形状: torch.Size([2, 4, 4])
6.2 多头注意力实现(Transformer核心)
import torch
import torch.nn as nn
class Attention(nn.Module):
def __init__(self,
dim, # 输入token的dim
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_ratio=0.,
proj_drop_ratio=0.):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop_ratio)
def forward(self, x):
# [batch_size, num_patches + 1, total_embed_dim]
B, N, C = x.shape
# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
# reshape: -> [batch_size, num_patches + 1, total_embed_dim]
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# 示例调用
if __name__ == "__main__":
# 创建输入张量 [batch_size, num_patches+1, embed_dim]
batch_size = 2
num_patches = 196 # 例如ViT中的14x14图像块
embed_dim = 768 # 与dim参数一致
x = torch.randn(batch_size, num_patches + 1, embed_dim)
# 初始化Attention模块
attention = Attention(
dim=embed_dim,
num_heads=8,
qkv_bias=True,
attn_drop_ratio=0.1,
proj_drop_ratio=0.1
)
# 前向传播
output = attention(x)
# 验证输出形状 (应与输入形状相同)
print("Input shape:", x.shape) # torch.Size([2, 197, 768])
print("Output shape:", output.shape) # torch.Size([2, 197, 768])
七、注意力可视化:看看模型在关注什么
7.1 如何可视化注意力
我们可以将注意力权重可视化为热力图,直观地看到模型在处理句子时关注了哪些词。
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import matplotlib.font_manager as fm
def visualize_attention(sentence, attn_weights, head=0, font_path=None):
"""
可视化注意力权重(优化中文显示版本)
sentence: 原始句子(词列表)
attn_weights: 注意力权重 [heads, seq_len, seq_len]
head: 要可视化的头
font_path: 字体文件路径(可选,用于更好地显示中文)
"""
# 设置中文字体支持
if font_path:
# 如果指定了字体路径,则使用指定字体
prop = fm.FontProperties(fname=font_path)
plt.rcParams['font.family'] = prop.get_name()
else:
# 尝试使用系统默认中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'FangSong', 'Microsoft YaHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False # 解决负号'-'显示为方块的问题
plt.figure(figsize=(10, 8))
sns.heatmap(
attn_weights[head].detach().numpy(),
xticklabels=sentence,
yticklabels=sentence,
cmap='viridis',
annot=True,
fmt=".2f",
square=True # 使单元格为正方形
)
plt.title(f'注意力头 {head + 1}', fontsize=16)
plt.xlabel('键 (Key)', fontsize=12)
plt.ylabel('查询 (Query)', fontsize=12)
# 旋转标签以便更好地显示
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
# 示例使用
sentence = ["我", "喜欢", "吃", "苹果"]
# 假设我们有一个4头注意力,取第一个头
attn_weights = torch.rand(4, 4, 4) # 随机生成注意力权重
print("原始句子:", sentence)
print("注意力权重形状:", attn_weights.shape)
# 可视化第一个注意力头
visualize_attention(sentence, attn_weights, head=0)
八、不同场景下的QKV:Encoder-Decoder注意力
8.1 Transformer架构中的两种注意力
Transformer包含两种主要的注意力机制:
- Encoder Self-Attention:
- Q、K、V都来自Encoder输入
- 用于理解输入序列内部关系
- Decoder-Encoder Attention:
- Q来自Decoder
- K和V来自Encoder
- 让Decoder在生成输出时关注输入的关键部分
九、总结与学习建议
9.1 QKV核心要点总结
概念 | 说明 | 关键点 |
---|---|---|
Q(Query) | “提问者”,表示当前关注点 | 决定"我想知道什么" |
K(Key) | “标签”,表示内容特征 | 决定"这段内容关于什么" |
V(Value) | “内容”,表示实际信息 | 决定"提供什么信息" |
Self-Attention | QKV来自同一输入 | 捕捉序列内部关系 |
Multi-Head | 多组QKV并行计算 | 从不同角度理解输入 |
9.2 学习建议
- 动手实践:运行提供的代码,修改参数观察输出变化
- 可视化探索:使用BERT等预训练模型查看真实注意力分布
- 深入阅读:阅读原始论文《Attention is All You Need》
- 扩展学习:了解Transformer变体(如BERT、GPT)中的QKV应用
十、结语
QKV不是神秘的魔法,而是通过简单线性变换从输入数据中生成的——只是这些变换的参数是模型在训练中自动学习得到的。理解QKV是掌握Transformer的关键一步,希望本文的详细解释、流程图和代码能帮助你彻底掌握这一核心概念!
记住:Q是"我想知道什么",K是"这段内容关于什么",V是"这段内容的实际信息"。当你下次看到QKV时,就想象自己在图书馆里找书,这就是注意力机制的精髓!