核心思想
- 用多个较小的 前馈网络 (FFN) 替换原本的大 FFN 层。
- 每个 token 只经过少数几个专家 (top-k),而不是所有专家,提升计算效率。
- 另有 共享专家 (shared experts),对所有 token 都进行计算,确保模型稳定性和表达能力。
- 非激活专家不会被调用,因此不会参与前向和反向传播,减少计算量。
实现方式
- 门控 (Gating):根据输入特征选择 top-k 专家并分配权重。
- 路由与计算:激活专家处理对应 token,并加权聚合输出。
- 分布式 MoE:专家跨设备部署,输入 token 通过通信操作分发到对应专家,再聚合结果。
负载均衡 (Load Balancing)
- 如果路由过于集中在少数专家,会导致:
- 一部分专家过度训练,另一部分训练不足。
- 计算资源浪费,影响模型效果。
- 常见解决方法:
- 在打分阶段加入噪声,使专家选择更均匀,特别在训练早期。
- 添加负载均衡损失在,惩罚专家选择过度不均。
- 训练时随机禁用部分专家,防止过度依赖。
- 在打分结果中引入动态偏置,帮助提升专家利用率(Deepseek的无辅助损失函数)。
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
Deepseek-V3源码Moe结构简化版实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class Gate(nn.Module):
def __init__(self, dim: int, n_experts: int, topk: int, score_func: str = "softmax", route_scale: float = 1.0):
super().__init__()
self.dim = dim
self.n_experts = n_experts
self.topk = topk
self.score_func = score_func
self.route_scale = route_scale
self.weight = nn.Parameter(torch.empty(n_experts, dim))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor):
# (batch, dim) @ (dim, n_experts)^T = (batch, n_experts)
scores = x @ self.weight.t()
if self.score_func == "softmax":
scores = scores.softmax(dim=-1)
else:
scores = scores.sigmoid()
# 选择 top-k 专家
topk_scores, topk_indices = torch.topk(scores, self.topk, dim=-1)
# 归一化
if self.score_func == "sigmoid":
topk_scores = topk_scores / (topk_scores.sum(dim=-1, keepdim=True) + 1e-9)
topk_scores = topk_scores * self.route_scale
return topk_scores, topk_indices
class Expert(nn.Module):
def __init__(self, dim: int, inter_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, inter_dim)
self.w2 = nn.Linear(inter_dim, dim)
def forward(self, x: torch.Tensor):
return self.w2(F.silu(self.w1(x)))
class MoE(nn.Module):
def __init__(self, dim: int, n_experts: int, topk: int, inter_dim: int):
super().__init__()
self.dim = dim
self.gate = Gate(dim, n_experts, topk)
self.experts = nn.ModuleList([Expert(dim, inter_dim) for _ in range(n_experts)])
def forward(self, x: torch.Tensor):
shape = x.size()
x = x.view(-1, self.dim)
# 得到选择的专家及权重
weights, indices = self.gate(x) # (batch, topk), (batch, topk)
y = torch.zeros_like(x)
for k in range(self.gate.topk):
expert_idx = indices[:, k]
expert_weight = weights[:, k]
for i in range(self.gate.n_experts):
mask = (expert_idx == i)
if mask.any():
y[mask] += self.experts[i](x[mask]) * expert_weight[mask, None]
return y.view(shape)