Qwen3_moe模型代码解析
1) 顶层:Qwen3MoeModel.forward(Embedding → 多层解码器 → RMSNorm)
要点
- 位置编码:MoE 是 1D RoPE(
position_ids形状(1,B,S))。 - 掩码:根据配置可能是标准因果或滑窗因果。
- Cache:默认用
DynamicCache(用于增量解码)。
2) 解码器层:Qwen3MoeDecoderLayer.forward(Self-Attn → 残差 → MoE/MLP → 残差)
要点
- MoE 层与普通 MLP 层互斥;MoE 层额外产生
router_logits(被上层输出记录用于负载均衡辅助损失)。
3) 注意力:Qwen3MoeAttention.forward(Q/K/V → RoPE → 注意力 → 合并头)
要点
- 与 Qwen2 系列类似,GQA:
n_heads = num_attention_heads,n_kv = num_key_value_heads,通过repeat_kv对齐。 - 层内对 Q/K 施加 RMSNorm(按头维) 再做 RoPE(这是 Qwen3 MoE 和一些实现的一个小差别)。
4) 稀疏 MoE:Qwen3MoeSparseMoeBlock.forward(Gating → Top-k 路由 → 专家并行 → 汇聚)
要点
- 该实现是token-level routing;每 token 选 top-k 专家;支持
norm_topk_prob对 top-k 权重归一化。 - 通过
index_add_将各专家输出按原 token 位置汇聚。
5) RoPE:Qwen3MoeRotaryEmbedding.forward(1D 位置 → cos/sin)
要点
- 与标准 1D RoPE 一致(没有多模态 3D 拆段)。
dynamic_rope_update允许动态扩展(取决于rope_scaling策略)。
6) 语言建模头:Qwen3MoeForCausalLM.forward(LM Head & Router Loss)
要点
logits_to_keep:只在末 K 个时间步计算lm_head,显存友好。aux_loss:负载均衡损失,鼓励专家使用更均匀。
7) 掩码构造与缓存(顶层)
8) 形状清单(常用变量)
B: batch size;S: 当前序列长度;H: hidden_size;V: vocab_size头部:
n_heads = num_attention_heads,n_kv = num_key_value_heads,d = head_dim = H / n_heads注意力中:
- Q:
(B,n_heads,S,d);K/V:(B,n_kv,S,d)→repeat_kv→(B,n_heads,S,d) - 权重
(B,n_heads,S,S);输出(B,S,H)
- Q:
9) 常见坑与对策
- inputs 选择:
(input_ids is None) XOR (inputs_embeds is not None)必须成立,否则抛错。 - RoPE 维度:
position_ids必须(1,B,S);若用 cache,需要正确设置cache_position使位置连续。 - 滑窗注意力:窗口
W太小会影响长程依赖;太大则近似全因果。确保与训练/推理对齐。 - MoE 路由:
num_experts_per_tok (top_k)影响吞吐与均衡;norm_topk_prob=True时要注意与训练策略匹配。 - 负载均衡损失:
output_router_logits=True时才会收集所有层的router_logits;注意与attention_mask一起计算避免 padding 干扰。 - 精度:注意力
softmax强制float32再 cast 回来,避免数值不稳。
10) 端到端数字化算例(便于核对)
假设:
B=2, S=128, H=4096, V=151936;n_heads=32 → d=128;n_kv=8 → num_key_value_groups=4;sliding_window=None(全因果);use_cache=True首次前向past=None;- 第 4 层是 MoE 层:
num_experts=8, top_k=2, moe_intermediate_size=11008,其它层为密集 MLP:intermediate_size=11008。
流程
inputs_embeds = embed_tokens(input_ids)→(2,128,4096)cache_position = [0..127],position_ids=(1,2,128)causal_mask=(2,1,128,128)上三角 -inf进入第 1 层:
- Q/K/V 线性:
(2,128,4096) → Q:(2,128,4096) K/V:(2,128,1024) - 视图→
Q:(2,32,128,128);K/V:(2,8,128,128)→repeat_kv→(2,32,128,128) - RoPE:
apply_rotary_pos_emb - 注意力:权重
(2,32,128,128)→ 输出(2,32,128,128)→ 合并头(2,128,4096) - 残差 + MLP (SwiGLU):
(2,128,4096)→(2,128,11008)→(2,128,4096)
- Q/K/V 线性:
第 4 层(MoE):
- Gate:
(B*S,H)=(256,4096) → (256,8)→ softmax → top2 - 对被命中专家的 tokens 送入各自
MLP_e:(N_e,4096)→(N_e,11008)→(N_e,4096),乘以各 token 对应权重 index_add_汇聚回(256,4096)→ reshape(2,128,4096)- 残差
router_logits记录(供 loss)
- Gate:
L 层结束 →
RMSNorm→last_hidden_state=(2,128,4096)lm_head:(4096→V)只算末K=32步 →logits=(2,32,V)如有
labels:交叉熵 + 若output_router_logits=True再加aux_loss(乘以router_aux_loss_coef)。