Qwen3 中旋转位置编码

发布于:2025-09-13 ⋅ 阅读:(21) ⋅ 点赞:(0)

模拟配置类


class MockConfig:
    """Minimal stand-in for a model config.

    Carries only the attributes read by the RoPE code in this example; all
    values are fixed toy settings.
    """

    def __init__(self):
        hidden, heads = 512, 8

        self.max_position_embeddings = 2048
        self.rope_theta = 10000.0  # base of the rotary frequencies
        self.hidden_size = hidden
        self.num_attention_heads = heads
        self.head_dim = hidden // heads  # 512 // 8 = 64
        self.rope_scaling = None  # None selects the "default" RoPE variant

Qwen3MoeRotaryEmbedding模块

import torch
import torch.nn as nn

def default_rope_init(config, device=None):
    """Build the inverse frequencies for the default (unscaled) RoPE.

    Args:
        config: Object exposing ``rope_theta`` plus either ``head_dim`` or
            ``hidden_size`` (ideally together with ``num_attention_heads``).
        device: Optional device the returned tensor is moved to.

    Returns:
        A tuple ``(inv_freq, attention_scaling)``. ``inv_freq`` is a float32
        tensor holding ``rope_theta ** (-i / dim)`` for ``i = 0, 2, 4, ...``
        below ``dim``; ``attention_scaling`` is always ``1.0`` for this variant.
    """
    # RoPE is applied per attention head, so the rotary dimension is the
    # per-head size. Prefer an explicit head_dim; if it is absent (or None),
    # derive it from hidden_size and the head count rather than using the full
    # hidden size, which would yield frequencies for the wrong dimension.
    dim = getattr(config, "head_dim", None)
    if dim is None:
        num_heads = getattr(config, "num_attention_heads", None)
        # Without a head count the hidden size is the only information left.
        dim = config.hidden_size // num_heads if num_heads else config.hidden_size

    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (config.rope_theta ** exponents)
    # Debug output kept on purpose: the walkthrough shows this line's result.
    print("inv_freq:",inv_freq.shape)
    return inv_freq.to(device), 1.0  # inv_freq, attention_scaling

# Registry mapping a RoPE variant name to the function that builds its
# inverse frequencies; only the "default" variant exists in this example.
ROPE_INIT_FUNCTIONS = {"default": default_rope_init}


class Qwen3MoeRotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used by rotary position embedding.

    The inverse frequencies are computed once at construction time by the
    init function registered for the configured RoPE variant; ``forward``
    turns position ids into per-position cos/sin values.
    """

    inv_freq: torch.Tensor  # declared for linters; the value comes from `register_buffer`

    def __init__(self, config: MockConfig, device=None):
        super().__init__()
        # Backward compatibility: the key "rope_type" used to be called "type".
        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            legacy_type = scaling_cfg.get("type")
            self.rope_type = scaling_cfg.get("rope_type", legacy_type)
        else:
            self.rope_type = "default"

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        freqs, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer follows the module across devices but is
        # left out of the state dict.
        self.register_buffer("inv_freq", freqs, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: 2-D tensor of positions, indexed as (batch, seq).

        Returns:
            Two tensors of shape (batch, seq, 2 * F), where F is the length of
            ``inv_freq``, cast to ``x.dtype``.
        """
        n_batch = position_ids.shape[0]
        # (F,) -> (1, F, 1) -> (batch, F, 1)
        inv_freq_col = self.inv_freq[None, :, None].float().expand(n_batch, -1, 1).to(x.device)
        # (batch, seq) -> (batch, 1, seq)
        positions_row = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            # (batch, F, 1) @ (batch, 1, seq) -> (batch, F, seq) -> (batch, seq, F)
            angles = (inv_freq_col.float() @ positions_row.float()).transpose(1, 2)
            # Duplicate the angles so both halves of the last dimension share them.
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # Debug output kept on purpose: the walkthrough shows these tensors.
        print("cos:",cos)
        print("sin:",sin)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

示例

config = MockConfig()
rope = Qwen3MoeRotaryEmbedding(config)

# Toy problem size for the demo.
batch_size, seq_length = 2, 8
num_heads = config.num_attention_heads  # 8
head_dim = config.head_dim  # 64

# Random queries and keys laid out as (batch, seq, heads, head_dim) = (2, 8, 8, 64).
qk_shape = (batch_size, seq_length, num_heads, head_dim)
q = torch.randn(qk_shape)
k = torch.randn(qk_shape)

# Positions 0..seq_length-1, repeated for every batch row: (8,) -> (1, 8) -> (2, 8).
position_ids = torch.arange(seq_length)[None, :].expand(batch_size, -1)

# q is passed only so the module can pick up its device and dtype.
cos, sin = rope(q, position_ids)

print(f"\nRoPE输出:")
print(f"  - cos: {cos.shape}")
print(f"  - sin: {sin.shape}")
    
inv_freq: torch.Size([32])
cos: tensor([[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.5403,  0.7318,  0.8460,  ...,  1.0000,  1.0000,  1.0000],
         [-0.4161,  0.0709,  0.4315,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [ 0.2837, -0.8209, -0.9461,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.9602, -0.2114, -0.9731,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.7539,  0.5114, -0.7004,  ...,  1.0000,  1.0000,  1.0000]],

        [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.5403,  0.7318,  0.8460,  ...,  1.0000,  1.0000,  1.0000],
         [-0.4161,  0.0709,  0.4315,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [ 0.2837, -0.8209, -0.9461,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.9602, -0.2114, -0.9731,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.7539,  0.5114, -0.7004,  ...,  1.0000,  1.0000,  1.0000]]])
sin: tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 8.4147e-01,  6.8156e-01,  5.3317e-01,  ...,  2.3714e-04,
           1.7783e-04,  1.3335e-04],
         [ 9.0930e-01,  9.9748e-01,  9.0213e-01,  ...,  4.7427e-04,
           3.5566e-04,  2.6670e-04],
         ...,
         [-9.5892e-01, -5.7113e-01,  3.2394e-01,  ...,  1.1857e-03,
           8.8914e-04,  6.6676e-04],
         [-2.7942e-01, -9.7740e-01, -2.3037e-01,  ...,  1.4228e-03,
           1.0670e-03,  8.0011e-04],
         [ 6.5699e-01, -8.5931e-01, -7.1372e-01,  ...,  1.6600e-03,
           1.2448e-03,  9.3346e-04]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 8.4147e-01,  6.8156e-01,  5.3317e-01,  ...,  2.3714e-04,
           1.7783e-04,  1.3335e-04],
         [ 9.0930e-01,  9.9748e-01,  9.0213e-01,  ...,  4.7427e-04,
           3.5566e-04,  2.6670e-04],
         ...,
         [-9.5892e-01, -5.7113e-01,  3.2394e-01,  ...,  1.1857e-03,
           8.8914e-04,  6.6676e-04],
         [-2.7942e-01, -9.7740e-01, -2.3037e-01,  ...,  1.4228e-03,
           1.0670e-03,  8.0011e-04],
         [ 6.5699e-01, -8.5931e-01, -7.1372e-01,  ...,  1.6600e-03,
           1.2448e-03,  9.3346e-04]]])

RoPE输出:
  - cos: torch.Size([2, 8, 64])
  - sin: torch.Size([2, 8, 64])

应用RoPE到查询和键

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [x1, x2]`` split at ``x.shape[-1] // 2`` along the last axis,
    returns ``[-x2, x1]``.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; gets one extra axis inserted at ``unsqueeze_dim``.
        sin: Sine table; gets one extra axis inserted at ``unsqueeze_dim``.
        position_ids: Unused by this function.
        unsqueeze_dim: Where the broadcast axis is inserted into ``cos`` and
            ``sin``. Assuming they are (batch, seq, head_dim), use 1 when q/k
            are (batch, heads, seq, head_dim) and 2 when q/k are
            (batch, seq, heads, head_dim).

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # t * cos + rotate_half(t) * sin, applied elementwise with broadcasting.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

# The demo tensors q/k are laid out as (batch, seq, heads, head_dim), while
# cos/sin are (batch, seq, head_dim). The broadcast axis therefore has to be
# inserted at dim 2 (the heads axis). With the default unsqueeze_dim=1 the
# position axis of cos/sin would line up with the *heads* axis of q/k, which
# only goes unnoticed because seq_length == num_heads == 8 in this example.
q_rotated, k_rotated = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

print(f"\n应用RoPE后:")
print(f"  - 旋转后的查询 (q_rotated): {q_rotated.shape}")
print(f"  - 旋转后的键 (k_rotated): {k_rotated.shape}")

# 6. Sanity-check basic RoPE properties
print(f"\n=== RoPE性质验证 ===")
# The rotation must not change the tensor shapes.
assert q_rotated.shape == q.shape, "查询张量形状不一致"
assert k_rotated.shape == k.shape, "键张量形状不一致"

print("✓ 查询和键张量形状保持一致")

# 7. Show the cos/sin values at a few positions
print(f"\n=== 不同位置的RoPE值示例 ===")
print("位置0的cos值前5维:", cos[0, 0, :5].tolist())
print("位置0的sin值前5维:", sin[0, 0, :5].tolist())
print("位置3的cos值前5维:", cos[0, 3, :5].tolist())
print("位置3的sin值前5维:", sin[0, 3, :5].tolist())

# 8. Orthogonality check
print(f"\n=== 正交性验证 ===")
# At a given position every vector is rotated by the same orthogonal
# transform, so the q·k inner product of two vectors at the SAME position is
# unchanged. Vectors at different positions are rotated by different angles,
# so their raw inner product is not preserved in general.
# Position 0 has cos == 1 and sin == 0 (identity), so a non-zero position is
# used to make the check meaningful.
check_pos, check_head = 1, 0
original_inner_prod = torch.sum(q[0, check_pos, check_head, :] * k[0, check_pos, check_head, :])
rotated_inner_prod = torch.sum(
    q_rotated[0, check_pos, check_head, :] * k_rotated[0, check_pos, check_head, :]
)

print(f"位置{check_pos}处q·k的原始内积: {original_inner_prod:.6f}")
print(f"位置{check_pos}处q·k的旋转后内积: {rotated_inner_prod:.6f}")
print(f"差异: {abs(original_inner_prod - rotated_inner_prod):.6f}")
# Allow for float32 rounding in the 64-term sums.
assert torch.allclose(original_inner_prod, rotated_inner_prod, atol=1e-4), "同一位置的q·k内积应保持不变"



print(f"\n=== 示例完成 ===")
print("RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码")
应用RoPE后:
  - 旋转后的查询 (q_rotated): torch.Size([2, 8, 8, 64])
  - 旋转后的键 (k_rotated): torch.Size([2, 8, 8, 64])

=== RoPE性质验证 ===
✓ 查询和键张量形状保持一致

=== 不同位置的RoPE值示例 ===
位置0的cos值前5维: [1.0, 1.0, 1.0, 1.0, 1.0]
位置0的sin值前5维: [0.0, 0.0, 0.0, 0.0, 0.0]
位置3的cos值前5维: [-0.9899924993515015, -0.6279267072677612, -0.11596616357564926, 0.3009673058986664, 0.5827536582946777]
位置3的sin值前5维: [0.14112000167369843, 0.7782725095748901, 0.9932531714439392, 0.9536344408988953, 0.8126488924026489]

=== 正交性验证 ===
位置0和1的原始内积: -1.464770
位置0和1的旋转后内积: -1.464770
差异: 0.000000

=== 示例完成 ===
RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码