Preface:
MLA (Multi-head Latent Attention) is designed to improve inference efficiency and reduce the consumption of compute resources. Its core idea is to optimize KV-cache usage by moving information into a compressed latent representation.
The main technical features of MLA are:
- KV compression with latent variables: the keys and values are jointly compressed into a low-dimensional latent vector, which significantly shrinks the KV cache at inference time and lowers memory usage. During computation the original information is recovered through an up-projection, balancing compression efficiency against accuracy.
- Low-rank projection of the queries: the queries are also compressed with a low-rank projection (down-projection followed by up-projection), which reduces activation memory during training; note that this step does not affect the KV cache.
- Dynamic sequence handling: the mechanism is optimized for variable-length inputs and handles sentences of different lengths efficiently (e.g. long-context dialogue; see the RoPE discussion below).
Contents
- KV-cache
- Introduction to MLA
- MLA + RoPE
- Mathematical principles of MLA
- PyTorch code
1 KV-cache
1.1 MHA (Multi-Head Attention)
1.2 KV-cache
During autoregressive generation, each newly generated token depends on the information of all previous tokens, so generating each new token would require recomputing self-attention over the whole sequence. This is very inefficient, because a large amount of computation is wasted repeating work for tokens that have already been generated.
The KV-Cache mechanism was proposed to shorten inference time by solving exactly this problem. During generation, the key and value vectors that have already been computed are stored in a cache, so that when generating subsequent tokens the keys and values of earlier tokens can be read directly from the cache instead of being recomputed. Concretely, when a new token is generated, the model only needs to compute the query vector of this new token, compute attention scores against the cached key vectors, and then combine these scores with the cached value vectors to obtain the new token's output representation.
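As a minimal sketch of this idea (a toy single-head example of my own, not the MLA implementation from Section 4), each decode step computes only the new token's query and appends the new key/value pair to the cache:

```python
import torch

# Toy single-head KV caching during autoregressive decoding.
d_head = 64
W_q = torch.randn(d_head, d_head)
W_k = torch.randn(d_head, d_head)
W_v = torch.randn(d_head, d_head)

k_cache, v_cache = [], []          # grow by one entry per generated token

def decode_step(h_t):
    """h_t: hidden state of the newest token, shape (d_head,)."""
    q_t = h_t @ W_q                # only the new token's query is computed
    k_cache.append(h_t @ W_k)      # cache the new key ...
    v_cache.append(h_t @ W_v)      # ... and the new value
    K = torch.stack(k_cache)       # (t, d_head): all cached keys
    V = torch.stack(v_cache)       # (t, d_head): all cached values
    scores = (K @ q_t) / d_head ** 0.5
    attn = torch.softmax(scores, dim=-1)
    return attn @ V                # output representation of the new token

for _ in range(4):                 # generate a few tokens
    out_t = decode_step(torch.randn(d_head))
```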
The size of the KV-Cache depends on the following parameters:
- $n_h$: number of attention heads per layer.
- $d_h$: dimension of each attention head, i.e. the dimension of each head's Key and Value.
- $l$: number of layers in the model.
Each token therefore needs a KV-cache of $2 \cdot n_h \cdot d_h \cdot l$ values (one Key and one Value of total dimension $n_h d_h$ per layer).
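For a rough sense of scale (my own example, assuming a LLaMA-7B-like configuration with $n_h = 32$, $d_h = 128$, $l = 32$ and fp16 storage, not values from this post): each token needs $2 \times 32 \times 128 \times 32 = 262{,}144$ cached values, about 512 KB, so a 4096-token context already occupies roughly 2 GB of KV cache per sequence.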
KV-cache of different attention mechanisms:
2 MLA (Multi-head Latent Attention)
Multi-head Latent Attention (MLA) is a new attention mechanism that compresses the keys and values into a smaller shared representation, called a latent vector. This shrinks the KV cache while maintaining, or even improving, performance.
MLA introduces two key innovations:
- Low-Rank Key-Value Compression
- Decoupled Rotary Position Embedding (RoPE)
2.1 MLA architecture
2.2 Computation flow
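As a sketch of the flow without the RoPE terms (using the DeepSeek-V2 notation, which this post assumes rather than defines; the decoupled RoPE part is covered in the next section), the hidden state $h_t$ is first down-projected into latent vectors and then up-projected back into per-head queries, keys and values:

$$c_t^{KV} = W^{DKV} h_t,\qquad k_t^{C} = W^{UK} c_t^{KV},\qquad v_t^{C} = W^{UV} c_t^{KV}$$
$$c_t^{Q} = W^{DQ} h_t,\qquad q_t^{C} = W^{UQ} c_t^{Q}$$
$$o_{t,i} = \sum_{j \le t} \mathrm{softmax}_j\!\left(\frac{q_{t,i}^{\top} k_{j,i}}{\sqrt{d_h}}\right) v_{j,i},\qquad u_t = W^{O}\,[o_{t,1};\dots;o_{t,n_h}]$$

Only the low-dimensional latent vector $c_t^{KV}$ has to be cached; the up-projections can be re-applied (or absorbed into neighbouring matrices) at inference time.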
Reference:
MLA reduces the KV cache size by compressing the keys and values into a smaller latent vector and decoupling the position information (RoPE). Here’s how the cache size is calculated.
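In formulas (again borrowing the DeepSeek-V2 symbols, which are an assumption of this sketch): per token, standard MHA caches $2\, n_h d_h\, l$ values (a key and a value per head per layer), whereas MLA caches only the latent vector plus the small decoupled RoPE key described in the next section, i.e. $(d_c + d_h^R)\, l$ values, with $d_c \ll n_h d_h$.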
3 Decoupled Rotary Position Embedding (RoPE)
Rotary Position Embedding (RoPE) is a technique for encoding the position of tokens in a sequence. RoPE, however, is position-sensitive: it depends on each token's absolute position. This causes a problem when combined with low-rank compression, because the position information gets mixed into the compressed keys and values, making them hard to reuse efficiently during inference (the rotation depends on position, so the key up-projection can no longer be absorbed into the query projection). To solve this RoPE problem, the following architecture is used:
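A hedged sketch of the decoupling (still in DeepSeek-V2 notation): a small extra set of dimensions carries the position information, is computed outside the compression path, and is shared across heads on the key side:

$$q_t^R = \mathrm{RoPE}\!\left(W^{QR} c_t^{Q}\right),\qquad k_t^R = \mathrm{RoPE}\!\left(W^{KR} h_t\right)$$
$$q_{t,i} = [\,q_{t,i}^{C};\, q_{t,i}^{R}\,],\qquad k_{t,i} = [\,k_{t,i}^{C};\, k_t^{R}\,]$$

Only the shared $k_t^R$ (one per token, not per head) is cached alongside $c_t^{KV}$.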
Reference:
The size of the KV-cache (including the RoPE part):
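As a concrete plug-in (using DeepSeek-V2's published settings, not values from this post, so treat them as an assumption): with $n_h = 128$, $d_h = 128$, $d_c = 4 d_h = 512$ and $d_h^R = d_h/2 = 64$, MLA caches $512 + 64 = 576$ values per token per layer instead of $2 \times 128 \times 128 = 32768$ for MHA, roughly a 57x reduction.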
4 PyTorch code
Common hyperparameters:
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 15 18:24:47 2025
@author: cxf
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Config:
    def __init__(self):
        self.vocab_size = 32000
        # embedding dimension
        self.d_model = 1024
        # number of attention heads
        self.n_heads = 8
        # dimension per head = 1024 // 8 = 128
        self.d_head = self.d_model // self.n_heads
        # decoupled RoPE dimension (half of d_head) = 64
        self.d_rope = self.d_head // 2
        # compressed KV dimension, d_kv_cache < n_heads * d_head
        self.d_kv_cache = 4 * self.d_head
        self.seq_len = 10
        self.batch_size = 1


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # the dimension must be even for rotary embeddings
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
        self.dim = dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        output = torch.cat((freqs, freqs), dim=-1)
        return output


def rotate_half(x):
    """Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    output = torch.cat((-x2, x1), dim=-1)
    return output


def apply_rotary(x, cos, sin):
    """Apply rotary embeddings to the first cos.shape[-1] dimensions of x."""
    # x.shape: (batch_size, seq_len, n_heads, d_h)
    # split x into a rotary part and an untouched base part
    x_rot, x_base = x.split(cos.shape[-1], dim=-1)
    x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
    output = torch.cat([x_rot, x_base], dim=-1)
    return output


config = Config()


class MemoryOptimizedMLA(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head = config.d_head
        self.d_split = config.d_model - config.d_rope
        # step 1: down-projections (compression)
        self.W_DQ = nn.Linear(config.d_model, config.d_kv_cache)
        self.W_DKV = nn.Linear(config.d_model, config.d_kv_cache)
        print("\n kv cache size ", config.d_kv_cache)
        # decoupled RoPE projections
        self.W_q_rope = nn.Linear(config.d_kv_cache, config.d_rope)
        self.W_k_rope = nn.Linear(config.d_model, config.d_rope)
        # step 2: up-projections
        self.W_UQ = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UK = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UV = nn.Linear(config.d_kv_cache, config.d_model)
        # rotary embedding over the per-head RoPE dimensions
        self.rotary = RotaryEmbedding(config.d_rope // config.n_heads)
        # step 3: output projection
        self.output = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        print("\n batch_size %d seq_len: %d d_model: %d " % (batch_size, seq_len, d_model))
        # step 1: down-projection (compression)
        print("\n step1 : down projection")
        # query compression and the compressed KV latent (this is what gets cached)
        q_c = self.W_DQ(x)
        kv_cache = self.W_DKV(x)
        # step 2: apply RoPE on the decoupled dimensions
        print("\n step2 : apply RoPE ")
        rotary_emb = self.rotary(seq_len)
        cos = torch.cos(rotary_emb).view(1, seq_len, 1, -1)
        sin = torch.sin(rotary_emb).view(1, seq_len, 1, -1)
        q_rot = self.W_q_rope(q_c)
        q_rot = q_rot.view(batch_size, seq_len, config.n_heads, -1)
        q_rot = apply_rotary(q_rot, cos, sin)
        k_rot_cache = self.W_k_rope(x)
        k_rot_cache = k_rot_cache.view(batch_size, seq_len, config.n_heads, -1)
        k_rot_cache = apply_rotary(k_rot_cache, cos, sin)
        # step 3: up-projection
        print("\n step3 : up projection ")
        q_base = self.W_UQ(q_c).view(batch_size, seq_len, config.n_heads, -1)
        k = self.W_UK(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        v = self.W_UV(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        # concatenate the compressed part and the RoPE part
        q = torch.cat([q_base, q_rot], dim=-1)
        k = torch.cat([k, k_rot_cache], dim=-1)
        # attention computation
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", attn, v)
        out = self.output(out.contiguous().view(batch_size, seq_len, -1))
        output = out, (kv_cache, k_rot_cache)
        print("\n output ", out.shape)
        return output


net = MemoryOptimizedMLA()
x = torch.randn((config.batch_size, config.seq_len, config.d_model))
out = net(x)
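With the toy configuration above, the short addition below (my own, matching the shapes the code actually produces) unpacks the return value and shows which tensors would be kept as the per-layer cache during generation:

```python
attn_out, (kv_cache, k_rot_cache) = out
print(attn_out.shape)     # torch.Size([1, 10, 1024])  attention output
print(kv_cache.shape)     # torch.Size([1, 10, 512])   compressed latent KV cache
print(k_rot_cache.shape)  # torch.Size([1, 10, 8, 8])  decoupled RoPE keys
```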
DeepSeek's Multi-Head Latent Attention - Lior Sinai