# 0. Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
# 1. Expert model
# A simple expert model; it can be any neural-network architecture.
class Expert(nn.Module):
    """A single expert: one linear projection from input to output features.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)
# 2. MoE (mixture of experts)
class MoE(nn.Module):
    """Mixture-of-experts layer with top-k gating.

    A linear gating network scores every expert for each input. Only the
    ``topk`` best-scoring experts receive a non-zero weight, and the layer
    returns the weighted sum of the expert outputs.

    Args:
        num_experts: Number of experts to create.
        input_size: Number of input features.
        output_size: Number of output features produced by each expert.
        topk: How many experts contribute to each output.

    Raises:
        ValueError: If ``topk`` is not in the range ``[1, num_experts]``.
    """

    def __init__(self, num_experts, input_size, output_size, topk):
        super().__init__()
        # topk == 0 would leave every gate logit at -inf and make softmax
        # return NaN; topk > num_experts would only fail later inside
        # torch.topk. Reject both up front with a clear message.
        if not 1 <= topk <= num_experts:
            raise ValueError(
                f"topk must be between 1 and num_experts ({num_experts}), "
                f"got {topk}"
            )
        self.num_experts = num_experts
        self.topk = topk
        # One independent expert network per slot.
        self.experts = nn.ModuleList(
            [Expert(input_size, output_size) for _ in range(num_experts)]
        )
        # Gating network: produces one logit per expert.
        self.gating_network = nn.Linear(input_size, num_experts)

    def forward(self, x):
        """Route ``x`` through the top-k experts and blend their outputs.

        Args:
            x: Tensor of shape ``(..., input_size)``, typically
                ``(batch_size, input_size)``.

        Returns:
            Tensor of shape ``(..., output_size)``.
        """
        # Gate logits, one per expert: (..., num_experts).
        gating_scores = self.gating_network(x)

        # Keep the k largest logits per input; both results are (..., topk).
        # The configured value lives on the instance, so read self.topk
        # (a bare ``topk`` is not defined in this scope).
        topk_gate_scores, topk_gate_index = gating_scores.topk(self.topk, dim=-1)

        # Start from a tensor filled with -inf and write the selected logits
        # back at their original positions. After softmax, the experts that
        # were not selected get a weight of exactly zero and the selected
        # ones sum to one. Shape stays (..., num_experts).
        masked_scores = torch.full_like(gating_scores, float("-inf"))
        masked_scores = masked_scores.scatter(-1, topk_gate_index, topk_gate_scores)
        gate_weights = F.softmax(masked_scores, dim=-1)

        # Every expert is evaluated on every input (dense computation).
        # Each output is (..., output_size); stacking on the second-to-last
        # axis gives (..., num_experts, output_size).
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )

        # Weighted combination of the expert outputs:
        # (..., 1, num_experts) @ (..., num_experts, output_size)
        #   -> (..., 1, output_size) -> (..., output_size)
        moe_output = torch.matmul(
            gate_weights.unsqueeze(-2), expert_outputs
        ).squeeze(-2)
        return moe_output
# 3. Usage example
# Each input is a feature vector of length 10.
input_size = 10
# Each output is a vector of length 5.
output_size = 5
# Use 3 experts.
num_experts = 3
# Route every input to its 2 highest-scoring experts. MoE requires this
# argument; constructing the model without it raises a TypeError.
topk = 2
# Build the MoE model.
moe_model = MoE(num_experts, input_size, output_size, topk)
# Create a single random input vector of shape (1, input_size).
input_vector = torch.randn(1, input_size)
# Forward pass.
output_vector = moe_model(input_vector)
# Print the output shape and values.
print(output_vector.shape, output_vector)
'''
torch.Size([1, 5]) tensor([[ 2.7343e-04, 4.0966e-01, -3.6634e-01, -8.9064e-01, 4.0759e-01]],
grad_fn=<SqueezeBackward1>)
'''