Improve MoE implementation (#841)

Sebastian Raschka
2025-09-22 15:21:06 -05:00
committed by GitHub
parent 20041fb94b
commit e742d8af2c
6 changed files with 177 additions and 250 deletions


@@ -134,14 +134,14 @@ class MoEFeedForward(nn.Module):
         super().__init__()
         self.num_experts_per_tok = cfg["num_experts_per_tok"]
         self.num_experts = cfg["num_experts"]
         self.emb_dim = cfg["emb_dim"]
         self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])
-        meta_device = torch.device("meta")  # to reduce memory pressure and only load them when used (trades compute for memory)
-        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])

     def forward(self, x):
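Since the experts are no longer placed on the meta device, they are materialized on the default device in cfg["dtype"] as soon as the module is constructed. A minimal construction-and-forward sketch, assuming a hypothetical toy config (the key names match the ones read in __init__ above; the values are illustrative only):

    import torch

    cfg = {
        "emb_dim": 64,                 # hypothetical toy size
        "moe_intermediate_size": 128,  # hypothetical toy size
        "num_experts": 8,              # hypothetical toy count
        "num_experts_per_tok": 2,      # hypothetical toy count
        "dtype": torch.float32,
    }
    moe = MoEFeedForward(cfg)          # MoEFeedForward as defined in this file
    x = torch.randn(2, 16, cfg["emb_dim"])  # (batch, seq_len, emb_dim)
    y = moe(x)
    print(y.shape)                     # torch.Size([2, 16, 64])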
@@ -149,24 +149,37 @@ class MoEFeedForward(nn.Module):
         topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
         topk_probs = torch.softmax(topk_scores, dim=-1)
-        expert_outputs = []
-        for e in range(self.num_experts):
-            hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)
-            out = self.fc3[e](hidden)
-            expert_outputs.append(out.unsqueeze(-2))
-        expert_outputs = torch.cat(expert_outputs, dim=-2)  # (b, t, num_experts, emb_dim)
-        gating_probs = torch.zeros_like(scores)
-        for i in range(self.num_experts_per_tok):
-            indices = topk_indices[..., i:i+1]
-            prob = topk_probs[..., i:i+1]
-            gating_probs.scatter_(dim=-1, index=indices, src=prob)
-        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)
-        # Weighted sum over experts
-        y = (gating_probs * expert_outputs).sum(dim=-2)
-        return y
+        batch, seq_len, _ = x.shape
+        x_flat = x.reshape(batch * seq_len, -1)
+        out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)
+        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)
+        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)
+        unique_experts = torch.unique(topk_indices_flat)
+        for expert_id_tensor in unique_experts:
+            expert_id = int(expert_id_tensor.item())
+            mask = topk_indices_flat == expert_id
+            if not mask.any():
+                continue
+            token_mask = mask.any(dim=-1)
+            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)
+            if selected_idx.numel() == 0:
+                continue
+            expert_input = x_flat.index_select(0, selected_idx)
+            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)
+            expert_out = self.fc3[expert_id](hidden)
+            mask_selected = mask[selected_idx]
+            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)
+            selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)
+            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))
+        return out_flat.reshape(batch, seq_len, self.emb_dim)

 class GroupedQueryAttention(nn.Module):
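The rewritten forward pass only evaluates experts that actually receive tokens: for each expert id appearing in topk_indices it gathers the routed tokens, applies that expert, weights the result by each token's gating probability, and scatter-adds it back into a flat output buffer. A minimal sketch of this gather-compute-scatter pattern in isolation, assuming a hypothetical stand-in expert() function in place of the per-expert fc1/fc2/fc3 layers and toy sizes:

    import torch

    num_tokens, emb_dim, num_experts, top_k = 6, 4, 3, 2  # hypothetical toy sizes
    x_flat = torch.randn(num_tokens, emb_dim)
    scores = torch.randn(num_tokens, num_experts)
    topk_scores, topk_indices = torch.topk(scores, top_k, dim=-1)
    topk_probs = torch.softmax(topk_scores, dim=-1)
    out_flat = torch.zeros_like(x_flat)

    def expert(t):  # stand-in for fc3(silu(fc1(t)) * fc2(t))
        return 2.0 * t

    for expert_id in torch.unique(topk_indices).tolist():
        mask = topk_indices == expert_id                       # (num_tokens, top_k)
        selected_idx = mask.any(dim=-1).nonzero(as_tuple=False).squeeze(-1)
        expert_out = expert(x_flat.index_select(0, selected_idx))
        slot = mask[selected_idx].int().argmax(dim=-1, keepdim=True)
        probs = torch.gather(topk_probs[selected_idx], dim=-1, index=slot).squeeze(-1)
        out_flat.index_add_(0, selected_idx, expert_out * probs.unsqueeze(-1))

    # out_flat now holds the probability-weighted sum of expert outputs per token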