Transformer中的QKV揭秘:从入门到实践(含流程图)

发布于:2025-08-31 ⋅ 阅读:(28) ⋅ 点赞:(0)

一、前言:为什么QKV这么重要?

Transformer模型彻底改变了自然语言处理领域,而其中的核心——注意力机制(Attention)的QKV三要素,是理解Transformer的关键。很多初学者看到Q(Query)、K(Key)、V(Value)就一头雾水:它们到底是什么?从哪里来?为什么需要它们?

本文将用最直观的方式,带你彻底理解QKV的原理,配有详细流程图可运行代码,保证小白也能轻松掌握!

二、QKV的本质:它们从哪里来?

2.1 QKV的来源

QKV都来自同一个地方——我们的输入数据! 只是通过不同的"通道"(线性变换)变成了三种不同形式。

输入序列X
线性变换 W^Q
线性变换 W^K
线性变换 W^V
Query Q
Key K
Value V

具体过程:

  1. 输入序列首先被转换为词向量矩阵X
  2. X分别与三个不同的权重矩阵WQ、WK、WVW^Q、W^K、W^VWQWKWV相乘
  3. 得到:Q=X⋅WQ,K=X⋅WK,V=X⋅WVQ = X·W^Q,K = X·W^K,V = X·W^VQ=XWQK=XWKV=XWV

这些权重矩阵WQ、WK、WVW^Q、W^K、W^VWQWKWV是模型在训练过程中自动学习得到的参数,不是预先设定好的。

2.2 一个小比喻:图书馆找书

想象你去图书馆找书:

  • Query(Q) 就是你想查找的"问题"(比如"人工智能入门")
  • Key(K) 就是图书馆里每本书的"标题标签"
  • Value(V) 就是每本书的"实际内容"

当你(Q)在图书馆里寻找时,系统会将你的查询与所有书的标题(K)进行匹配,找到最相关的书,然后把书的内容(V)呈现给你。

三、自注意力机制:让词自己关注自己

3.1 Self-Attention的核心特点

在Transformer的自注意力机制(Self-Attention) 中,Q、K、V都来自同一段文本。这意味着,当模型处理一个句子时,每个词都会:

  1. 作为"提问者"(Q)去询问其他词
  2. 作为"被询问对象"(K)被其他词提问
  3. 作为"回答内容"(V)提供给提问者参考

比如处理句子"我喜欢吃苹果"时:

  • “我"会作为Q去关注"喜欢”、“吃”、“苹果”
  • 同时"我"也会作为K和V被其他词关注

3.2 注意力计算的完整流程

在这里插入图片描述

计算步骤详解:

  1. 计算注意力分数:Q和K做点积,衡量它们之间的相关性
  2. 缩放:除以√d_k(d_k是K向量的维度),防止点积过大
  3. Softmax:将分数转换为概率分布(注意力权重)
  4. 加权求和:用这些权重对V进行加权求和,得到最终输出

公式表示为:Attention(Q,K,V)=Softmax(QKT/√dk)VAttention(Q, K, V) = Softmax(QKᵀ/√d_k)VAttention(Q,K,V)=Softmax(QKT/√dk)V

四、多头注意力:Transformer的真正创新

4.1 为什么需要多头?

单头注意力有个局限:它只能学习一种"关注模式"。就像你只能用一种方式看世界!

多头注意力(Multi-Head Attention) 让模型能够:

  • 同时从不同子空间学习信息
  • 捕捉更丰富的语言特征
  • 提高模型的表达能力

4.2 多头注意力工作原理

输入X
线性变换生成QKV
拆分为多个头
头1: 计算注意力
头2: 计算注意力
......
头n: 计算注意力
拼接所有头
线性变换
多头注意力输出

关键步骤:

  1. 将Q、K、V分别拆分成h个"头"(例如8头或16头)
  2. 每个头独立计算注意力
  3. 将所有头的输出拼接起来
  4. 通过一个线性变换得到最终输出

五、位置编码:让模型知道词的顺序

5.1 为什么需要位置编码?

自注意力机制本身不包含位置信息——它把输入当作一个"词袋",不知道词的顺序!

解决方案:添加位置编码(Positional Encoding),让模型知道每个词的位置。

5.2 位置编码与QKV的关系

词向量
输入Transformer
位置编码
生成QKV
注意力计算

位置编码在输入阶段就加到词向量上,然后再生成QKV:

  • 输入 = 词向量 + 位置编码
  • Q = (词向量 + 位置编码)·WQW^QWQ
  • K = (词向量 + 位置编码)·WKW^KWK
  • V = (词向量 + 位置编码)·WVW^VWV

六、完整代码实现

6.1 单头自注意力实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """单头自注意力机制"""
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        # 三个可学习的权重矩阵
        self.W_q = nn.Linear(embed_size, embed_size, bias=False)
        self.W_k = nn.Linear(embed_size, embed_size, bias=False)
        self.W_v = nn.Linear(embed_size, embed_size, bias=False)
        
    def forward(self, x):
        """
        x: 输入张量,形状为(batch_size, seq_length, embed_size)
        """
        # 生成Q, K, V
        Q = self.W_q(x)  # [batch_size, seq_length, embed_size]
        K = self.W_k(x)  # [batch_size, seq_length, embed_size]
        V = self.W_v(x)  # [batch_size, seq_length, embed_size]
        
        # 计算注意力分数: Q·K^T
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch_size, seq_length, seq_length]
        
        # 缩放: 除以√d_k
        d_k = self.embed_size
        attn_scores = attn_scores / (d_k ** 0.5)
        
        # Softmax获取注意力权重
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # 加权求和: attn_weights·V
        output = torch.matmul(attn_weights, V)  # [batch_size, seq_length, embed_size]
        
        return output, attn_weights

# 使用示例
batch_size = 2
seq_length = 4
embed_size = 8
x = torch.randn(batch_size, seq_length, embed_size)  # 随机输入

attention = SelfAttention(embed_size)
output, attn_weights = attention(x)

print("输入形状:", x.shape)  # torch.Size([2, 4, 8])
print("输出形状:", output.shape)  # torch.Size([2, 4, 8])
print("注意力权重形状:", attn_weights.shape)  # torch.Size([2, 4, 4])

输出结果:

输入形状: torch.Size([2, 4, 8])
输出形状: torch.Size([2, 4, 8])
注意力权重形状: torch.Size([2, 4, 4])

6.2 多头注意力实现(Transformer核心)

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self,
                 dim,  # 输入token的dim
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# 示例调用
if __name__ == "__main__":
    # 创建输入张量 [batch_size, num_patches+1, embed_dim]
    batch_size = 2
    num_patches = 196  # 例如ViT中的14x14图像块
    embed_dim = 768  # 与dim参数一致
    x = torch.randn(batch_size, num_patches + 1, embed_dim)

    # 初始化Attention模块
    attention = Attention(
        dim=embed_dim,
        num_heads=8,
        qkv_bias=True,
        attn_drop_ratio=0.1,
        proj_drop_ratio=0.1
    )

    # 前向传播
    output = attention(x)
    # 验证输出形状 (应与输入形状相同)
    print("Input shape:", x.shape)  # torch.Size([2, 197, 768])
    print("Output shape:", output.shape)  # torch.Size([2, 197, 768])

在这里插入图片描述

七、注意力可视化:看看模型在关注什么

7.1 如何可视化注意力

我们可以将注意力权重可视化为热力图,直观地看到模型在处理句子时关注了哪些词。

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import matplotlib.font_manager as fm


def visualize_attention(sentence, attn_weights, head=0, font_path=None):
    """
    可视化注意力权重(优化中文显示版本)
    sentence: 原始句子(词列表)
    attn_weights: 注意力权重 [heads, seq_len, seq_len]
    head: 要可视化的头
    font_path: 字体文件路径(可选,用于更好地显示中文)
    """
    # 设置中文字体支持
    if font_path:
        # 如果指定了字体路径,则使用指定字体
        prop = fm.FontProperties(fname=font_path)
        plt.rcParams['font.family'] = prop.get_name()
    else:
        # 尝试使用系统默认中文字体
        plt.rcParams['font.sans-serif'] = ['SimHei', 'FangSong', 'Microsoft YaHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False  # 解决负号'-'显示为方块的问题

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attn_weights[head].detach().numpy(),
        xticklabels=sentence,
        yticklabels=sentence,
        cmap='viridis',
        annot=True,
        fmt=".2f",
        square=True  # 使单元格为正方形
    )
    plt.title(f'注意力头 {head + 1}', fontsize=16)
    plt.xlabel('键 (Key)', fontsize=12)
    plt.ylabel('查询 (Query)', fontsize=12)

    # 旋转标签以便更好地显示
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()
    plt.show()


# 示例使用
sentence = ["我", "喜欢", "吃", "苹果"]
# 假设我们有一个4头注意力,取第一个头
attn_weights = torch.rand(4, 4, 4)  # 随机生成注意力权重

print("原始句子:", sentence)
print("注意力权重形状:", attn_weights.shape)

# 可视化第一个注意力头
visualize_attention(sentence, attn_weights, head=0)

在这里插入图片描述

八、不同场景下的QKV:Encoder-Decoder注意力

8.1 Transformer架构中的两种注意力

Transformer包含两种主要的注意力机制:

  1. Encoder Self-Attention
    • Q、K、V都来自Encoder输入
    • 用于理解输入序列内部关系
Encoder输出
生成Query Q
生成Key K
生成Value V
注意力计算
Decoder下一层
  1. Decoder-Encoder Attention
    • Q来自Decoder
    • K和V来自Encoder
    • 让Decoder在生成输出时关注输入的关键部分
Decoder输入
生成Query Q
Encoder输出
生成Key K
生成Value V
注意力计算
Decoder下一层

九、总结与学习建议

9.1 QKV核心要点总结

概念 说明 关键点
Q(Query) “提问者”,表示当前关注点 决定"我想知道什么"
K(Key) “标签”,表示内容特征 决定"这段内容关于什么"
V(Value) “内容”,表示实际信息 决定"提供什么信息"
Self-Attention QKV来自同一输入 捕捉序列内部关系
Multi-Head 多组QKV并行计算 从不同角度理解输入

9.2 学习建议

  1. 动手实践:运行提供的代码,修改参数观察输出变化
  2. 可视化探索:使用BERT等预训练模型查看真实注意力分布
  3. 深入阅读:阅读原始论文《Attention is All You Need》
  4. 扩展学习:了解Transformer变体(如BERT、GPT)中的QKV应用

十、结语

QKV不是神秘的魔法,而是通过简单线性变换从输入数据中生成的——只是这些变换的参数是模型在训练中自动学习得到的。理解QKV是掌握Transformer的关键一步,希望本文的详细解释、流程图和代码能帮助你彻底掌握这一核心概念!

记住:Q是"我想知道什么",K是"这段内容关于什么",V是"这段内容的实际信息"。当你下次看到QKV时,就想象自己在图书馆里找书,这就是注意力机制的精髓!


网站公告

今日签到

点亮在社区的每一天
去签到