mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00

Improve MoE implementation (#841)

commit e742d8af2c, parent 20041fb94b, committed by GitHub
@@ -134,14 +134,14 @@ class MoEFeedForward(nn.Module):
         super().__init__()
         self.num_experts_per_tok = cfg["num_experts_per_tok"]
         self.num_experts = cfg["num_experts"]
         self.emb_dim = cfg["emb_dim"]
         self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])
 
-        meta_device = torch.device("meta")  # to reduce memory pressure and only load them when used (trades compute for memory)
-        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
 
     def forward(self, x):
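For background: torch.device("meta") allocates parameters as shape-and-dtype stubs with no storage, which is why the old constructor could build all experts "for free" and defer real allocation to weight loading. A minimal standalone sketch of that behavior (illustrative only, not code from this repo):

    import torch
    import torch.nn as nn

    # A Linear created on the meta device has shapes but no storage
    fc = nn.Linear(4, 8, bias=False, device=torch.device("meta"))
    print(fc.weight.is_meta)   # True: no data behind this parameter

    # It must be materialized before use, e.g. by allocating empty
    # storage and copying real (here: random stand-in) weights into it
    fc = fc.to_empty(device="cpu")
    with torch.no_grad():
        fc.weight.copy_(torch.randn(8, 4))
    print(fc(torch.randn(2, 4)).shape)  # torch.Size([2, 8])

The commit drops this trick in favor of plain CPU allocation, which simplifies the weight-loading path (see the load_weights_into_qwen hunks below).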
@@ -149,24 +149,37 @@ class MoEFeedForward(nn.Module):
         topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
         topk_probs = torch.softmax(topk_scores, dim=-1)
 
-        expert_outputs = []
-        for e in range(self.num_experts):
-            hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)
-            out = self.fc3[e](hidden)
-            expert_outputs.append(out.unsqueeze(-2))
-        expert_outputs = torch.cat(expert_outputs, dim=-2)  # (b, t, num_experts, emb_dim)
+        batch, seq_len, _ = x.shape
+        x_flat = x.reshape(batch * seq_len, -1)
+        out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)
 
-        gating_probs = torch.zeros_like(scores)
+        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)
+        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)
 
-        for i in range(self.num_experts_per_tok):
-            indices = topk_indices[..., i:i+1]
-            prob = topk_probs[..., i:i+1]
-            gating_probs.scatter_(dim=-1, index=indices, src=prob)
-        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)
+        unique_experts = torch.unique(topk_indices_flat)
 
-        # Weighted sum over experts
-        y = (gating_probs * expert_outputs).sum(dim=-2)
-        return y
+        for expert_id_tensor in unique_experts:
+            expert_id = int(expert_id_tensor.item())
+            mask = topk_indices_flat == expert_id
+            if not mask.any():
+                continue
+
+            token_mask = mask.any(dim=-1)
+            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)
+            if selected_idx.numel() == 0:
+                continue
+
+            expert_input = x_flat.index_select(0, selected_idx)
+            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)
+            expert_out = self.fc3[expert_id](hidden)
+
+            mask_selected = mask[selected_idx]
+            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)
+            selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)
+
+            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))
+
+        return out_flat.reshape(batch, seq_len, self.emb_dim)
 
 
 class GroupedQueryAttention(nn.Module):
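This rewritten forward is the core of the change: instead of running every expert on every token and weighting the stacked outputs, it flattens tokens, iterates only over the experts that the router actually selected, and scatters each expert's weighted output back with index_add_. A toy sketch of the same routing pattern, with random routing and a multiply standing in for the expert FFN (illustrative sizes, not repo code):

    import torch

    num_tokens, emb_dim, num_experts, top_k = 6, 4, 3, 2
    x_flat = torch.randn(num_tokens, emb_dim)
    topk_indices = torch.randint(num_experts, (num_tokens, top_k))
    topk_probs = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)

    out_flat = torch.zeros(num_tokens, emb_dim)
    for expert_id in torch.unique(topk_indices).tolist():
        mask = topk_indices == expert_id                  # (num_tokens, top_k)
        selected_idx = mask.any(dim=-1).nonzero(as_tuple=False).squeeze(-1)
        expert_input = x_flat.index_select(0, selected_idx)
        expert_out = expert_input * (expert_id + 1)       # stand-in for the expert FFN
        # argmax picks the first top-k slot that chose this expert; with a real
        # torch.topk router each token selects a given expert at most once
        slot = mask[selected_idx].int().argmax(dim=-1, keepdim=True)
        probs = torch.gather(topk_probs[selected_idx], dim=-1, index=slot).squeeze(-1)
        out_flat.index_add_(0, selected_idx, expert_out * probs.unsqueeze(-1))

This trades the dense version's num_experts full-batch matmuls for a few gathers and smaller per-expert matmuls, which pays off when num_experts is large relative to num_experts_per_tok. The same two hunks are applied again below to the second copy of the class.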
@@ -215,14 +215,14 @@ class MoEFeedForward(nn.Module):
         super().__init__()
         self.num_experts_per_tok = cfg["num_experts_per_tok"]
         self.num_experts = cfg["num_experts"]
         self.emb_dim = cfg["emb_dim"]
         self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])
 
-        meta_device = torch.device("meta")  # to reduce memory pressure and only load them when used (trades compute for memory)
-        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
-        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device)
+        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"])
                                   for _ in range(cfg["num_experts"])])
 
     def forward(self, x):
@@ -230,24 +230,37 @@ class MoEFeedForward(nn.Module):
         topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
         topk_probs = torch.softmax(topk_scores, dim=-1)
 
-        expert_outputs = []
-        for e in range(self.num_experts):
-            hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)
-            out = self.fc3[e](hidden)
-            expert_outputs.append(out.unsqueeze(-2))
-        expert_outputs = torch.cat(expert_outputs, dim=-2)  # (b, t, num_experts, emb_dim)
+        batch, seq_len, _ = x.shape
+        x_flat = x.reshape(batch * seq_len, -1)
+        out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)
 
-        gating_probs = torch.zeros_like(scores)
+        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)
+        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)
 
-        for i in range(self.num_experts_per_tok):
-            indices = topk_indices[..., i:i+1]
-            prob = topk_probs[..., i:i+1]
-            gating_probs.scatter_(dim=-1, index=indices, src=prob)
-        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)
+        unique_experts = torch.unique(topk_indices_flat)
 
-        # Weighted sum over experts
-        y = (gating_probs * expert_outputs).sum(dim=-2)
-        return y
+        for expert_id_tensor in unique_experts:
+            expert_id = int(expert_id_tensor.item())
+            mask = topk_indices_flat == expert_id
+            if not mask.any():
+                continue
+
+            token_mask = mask.any(dim=-1)
+            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)
+            if selected_idx.numel() == 0:
+                continue
+
+            expert_input = x_flat.index_select(0, selected_idx)
+            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)
+            expert_out = self.fc3[expert_id](hidden)
+
+            mask_selected = mask[selected_idx]
+            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)
+            selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)
+
+            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))
+
+        return out_flat.reshape(batch, seq_len, self.emb_dim)
 
 
 class GroupedQueryAttention(nn.Module):
@@ -500,7 +513,7 @@ def load_weights_into_qwen(model, param_config, params):
         )
 
         # Feedforward weights
-        if "num_experts" in param_config:
+        if param_config.get("num_experts", 0) > 0:
             # Load router (gating) weights
             block.ff.gate.weight = assign(
                 block.ff.gate.weight,
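The new guard fixes a subtle bug: a dense config that merely defines num_experts as 0 would still satisfy the old membership test and be routed through the MoE loading branch. Plain dict semantics show the difference:

    param_config = {"num_experts": 0}      # dense model, but the key is present
    "num_experts" in param_config          # True  -> old check takes the MoE branch
    param_config.get("num_experts", 0) > 0 # False -> new check takes the dense branch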
@@ -525,10 +538,6 @@ def load_weights_into_qwen(model, param_config, params):
                     params[f"{prefix}.down_proj.weight"],
                     f"{prefix}.down_proj.weight"
                 )
-                # After assigning weights, move the expert layers from meta to CPU
-                block.ff.fc1[e] = block.ff.fc1[e].to("cpu")
-                block.ff.fc2[e] = block.ff.fc2[e].to("cpu")
-                block.ff.fc3[e] = block.ff.fc3[e].to("cpu")
 
         else:
             block.ff.fc1.weight = assign(
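These four lines were only needed under the meta-device scheme, and even then mostly as a safety net: if assign replaces each expert's parameter with a real tensor (as the surrounding hunks suggest), nothing meta remains inside the module afterwards, and with the experts now built directly on CPU the move is redundant. A small sketch of that reasoning (assumes wholesale parameter replacement; illustrative, not repo code):

    import torch
    import torch.nn as nn

    fc = nn.Linear(4, 8, bias=False, device=torch.device("meta"))
    # assign-style replacement: a real tensor takes the parameter's place
    fc.weight = nn.Parameter(torch.randn(8, 4))
    print(any(p.is_meta for p in fc.parameters()))  # False: .to("cpu") adds nothing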
@@ -11,6 +11,7 @@ from llms_from_scratch.qwen3 import (
     QWEN_CONFIG_06_B,
     Qwen3Model,
     Qwen3Tokenizer,
+    MoEFeedForward,
     RMSNorm,
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
@@ -113,6 +114,36 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
         "Expected MoEFeedForward in at least one transformer block"
 
 
+@torch.inference_mode()
+def test_moe_forward_matches_reference(dummy_cfg_moe):
+    torch.manual_seed(0)
+    moe = MoEFeedForward(dummy_cfg_moe)
+    x = torch.randn(2, 5, dummy_cfg_moe["emb_dim"])
+
+    scores = moe.gate(x)
+    topk_scores, topk_indices = torch.topk(scores, moe.num_experts_per_tok, dim=-1)
+    topk_probs = torch.softmax(topk_scores, dim=-1)
+
+    expert_outputs = []
+    for e in range(moe.num_experts):
+        hidden = torch.nn.functional.silu(moe.fc1[e](x)) * moe.fc2[e](x)
+        out = moe.fc3[e](hidden)
+        expert_outputs.append(out.unsqueeze(-2))
+    expert_outputs = torch.cat(expert_outputs, dim=-2)
+
+    gating_probs = torch.zeros_like(scores)
+    for i in range(moe.num_experts_per_tok):
+        indices = topk_indices[..., i:i+1]
+        prob = topk_probs[..., i:i+1]
+        gating_probs.scatter_(dim=-1, index=indices, src=prob)
+    gating_probs = gating_probs.unsqueeze(-1)
+
+    expected = (gating_probs * expert_outputs).sum(dim=-2)
+
+    actual = moe(x)
+    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
 def test_qwen3_kvcache_equivalence(cfg_name, request):
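The new test pins the sparse forward to a dense reference: evaluate all experts, scatter the top-k probabilities into a full gating tensor, take the weighted sum, and require the routed implementation to match within float tolerance. torch.testing.assert_close raises an AssertionError with a mismatch summary when tensors differ beyond atol/rtol, e.g.:

    import torch

    a = torch.tensor([1.0, 2.0])
    torch.testing.assert_close(a, a + 1e-7, atol=1e-5, rtol=1e-5)    # passes
    # torch.testing.assert_close(a, a + 1e-2, atol=1e-5, rtol=1e-5)  # would raise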