diff --git a/.gitignore b/.gitignore index 03abb4c..d89911d 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,7 @@ gemma-3-270m-it/ Qwen3-0.6B-Base/ Qwen3-0.6B/ tokenizer-base.json +tokenizer-reasoning.json tokenizer.json # Datasets diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 42c5f63..3963970 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -150,79 +150,52 @@ " super().__init__()\n", " self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n", " self.num_experts = cfg[\"num_experts\"]\n", + " self.emb_dim = cfg[\"emb_dim\"]\n", " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # meta device to reduce memory pressure when initializing the model before loading weights\n", - " meta_device = torch.device(\"meta\")\n", - " self.fc1 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", - " self.fc2 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", - " )\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", - " self.fc3 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", - " )\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", + " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", + " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", + " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", "\n", " def forward(self, x):\n", - " b, seq_len, embed_dim = x.shape\n", " scores = self.gate(x) # (b, seq_len, num_experts)\n", " topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", " topk_probs = torch.softmax(topk_scores, dim=-1)\n", - " \n", - " expert_outputs = []\n", - " for e in range(self.num_experts):\n", - " hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n", - " out = self.fc3[e](hidden)\n", - " expert_outputs.append(out.unsqueeze(-2))\n", - " expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n", "\n", - " gating_probs = torch.zeros_like(scores)\n", + " batch, seq_len, _ = x.shape\n", + " x_flat = x.reshape(batch * seq_len, -1)\n", + " out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n", "\n", - " for i in range(self.num_experts_per_tok):\n", - " indices = topk_indices[..., i:i+1]\n", - " prob = topk_probs[..., i:i+1]\n", - " gating_probs.scatter_(dim=-1, index=indices, src=prob)\n", - " gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n", - " \n", - " # Weighted sum over experts\n", - " y = (gating_probs * expert_outputs).sum(dim=-2)\n", - " return y\n", + " topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n", + " topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n", "\n", + " unique_experts = torch.unique(topk_indices_flat)\n", "\n", - " # For some reason, the version below is slower than the naive version\n", - " # above that computes all experts, even the unused ones\n", + " for expert_id_tensor in unique_experts:\n", + " expert_id = int(expert_id_tensor.item())\n", + " mask = topk_indices_flat == expert_id\n", + " if not mask.any():\n", + " continue\n", "\n", - " # def forward(self, x):\n", - " # scores = self.gate(x) # (b, seq_len, num_experts)\n", - " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", - " # topk_probs = torch.softmax(topk_scores, dim=-1)\n", - " # y = torch.zeros_like(x)\n", - " #\n", - " # for i in range(self.num_experts_per_tok):\n", - " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n", - " # expert_indices = topk_indices[..., i]\n", - " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n", - " #\n", - " # # For each expert, process only the tokens assigned to it\n", - " # for e in range(self.num_experts):\n", - " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n", - " # if mask.any():\n", - " # selected = x[mask] # (num_tokens_e, emb_dim)\n", - " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n", - " # y[mask] += prob[mask] * out\n", - " # return y" + " token_mask = mask.any(dim=-1)\n", + " selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n", + " if selected_idx.numel() == 0:\n", + " continue\n", + "\n", + " expert_input = x_flat.index_select(0, selected_idx)\n", + " hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n", + " expert_out = self.fc3[expert_id](hidden)\n", + "\n", + " mask_selected = mask[selected_idx]\n", + " slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n", + " selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n", + "\n", + " out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n", + "\n", + " return out_flat.reshape(batch, seq_len, self.emb_dim)" ] }, { @@ -829,7 +802,7 @@ " )\n", "\n", " # Feedforward weights\n", - " if \"num_experts\" in param_config:\n", + " if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n", " # Load router (gating) weights\n", " block.ff.gate.weight = assign(\n", " block.ff.gate.weight,\n", @@ -854,10 +827,6 @@ " params[f\"{prefix}.down_proj.weight\"],\n", " f\"{prefix}.down_proj.weight\"\n", " )\n", - " # After assigning weights, move the expert layers from meta to CPU\n", - " block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n", - " block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n", - " block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n", "\n", " else:\n", " block.ff.fc1.weight = assign(\n", diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb index 6e845fb..bb2de06 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb @@ -89,8 +89,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.34.3\n", - "tokenizers version: 0.21.4\n", + "huggingface_hub version: 0.35.0\n", + "tokenizers version: 0.22.1\n", "torch version: 2.7.1+cu128\n" ] } @@ -150,79 +150,52 @@ " super().__init__()\n", " self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n", " self.num_experts = cfg[\"num_experts\"]\n", + " self.emb_dim = cfg[\"emb_dim\"]\n", " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # meta device to reduce memory pressure when initializing the model before loading weights\n", - " meta_device = torch.device(\"meta\")\n", - " self.fc1 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", - " self.fc2 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", - " )\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", - " self.fc3 = nn.ModuleList([\n", - " nn.Linear(\n", - " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n", - " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", - " )\n", - " for _ in range(cfg[\"num_experts\"])]\n", - " )\n", + " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", + " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", + " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n", + " for _ in range(cfg[\"num_experts\"])])\n", "\n", " def forward(self, x):\n", - " b, seq_len, embed_dim = x.shape\n", " scores = self.gate(x) # (b, seq_len, num_experts)\n", " topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", " topk_probs = torch.softmax(topk_scores, dim=-1)\n", - " \n", - " expert_outputs = []\n", - " for e in range(self.num_experts):\n", - " hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n", - " out = self.fc3[e](hidden)\n", - " expert_outputs.append(out.unsqueeze(-2))\n", - " expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n", "\n", - " gating_probs = torch.zeros_like(scores)\n", + " batch, seq_len, _ = x.shape\n", + " x_flat = x.reshape(batch * seq_len, -1)\n", + " out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n", "\n", - " for i in range(self.num_experts_per_tok):\n", - " indices = topk_indices[..., i:i+1]\n", - " prob = topk_probs[..., i:i+1]\n", - " gating_probs.scatter_(dim=-1, index=indices, src=prob)\n", - " gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n", - " \n", - " # Weighted sum over experts\n", - " y = (gating_probs * expert_outputs).sum(dim=-2)\n", - " return y\n", + " topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n", + " topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n", "\n", + " unique_experts = torch.unique(topk_indices_flat)\n", "\n", - " # For some reason, the version below is slower than the naive version\n", - " # above that computes all experts, even the unused ones\n", + " for expert_id_tensor in unique_experts:\n", + " expert_id = int(expert_id_tensor.item())\n", + " mask = topk_indices_flat == expert_id\n", + " if not mask.any():\n", + " continue\n", "\n", - " # def forward(self, x):\n", - " # scores = self.gate(x) # (b, seq_len, num_experts)\n", - " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n", - " # topk_probs = torch.softmax(topk_scores, dim=-1)\n", - " # y = torch.zeros_like(x)\n", - " #\n", - " # for i in range(self.num_experts_per_tok):\n", - " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n", - " # expert_indices = topk_indices[..., i]\n", - " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n", - " #\n", - " # # For each expert, process only the tokens assigned to it\n", - " # for e in range(self.num_experts):\n", - " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n", - " # if mask.any():\n", - " # selected = x[mask] # (num_tokens_e, emb_dim)\n", - " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n", - " # y[mask] += prob[mask] * out\n", - " # return y" + " token_mask = mask.any(dim=-1)\n", + " selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n", + " if selected_idx.numel() == 0:\n", + " continue\n", + "\n", + " expert_input = x_flat.index_select(0, selected_idx)\n", + " hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n", + " expert_out = self.fc3[expert_id](hidden)\n", + "\n", + " mask_selected = mask[selected_idx]\n", + " slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n", + " selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n", + "\n", + " out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n", + "\n", + " return out_flat.reshape(batch, seq_len, self.emb_dim)" ] }, { @@ -582,10 +555,10 @@ { "data": { "text/plain": [ - "tensor([[[nan, nan, nan, ..., nan, nan, nan],\n", - " [nan, nan, nan, ..., nan, nan, nan],\n", - " [nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',\n", - " dtype=torch.bfloat16, grad_fn=)" + "tensor([[[ 0.3223, -0.0562, 0.2490, ..., 0.4551, -0.0542, 0.8242],\n", + " [ 0.0688, 0.0786, -0.0312, ..., 0.6406, -0.9141, 0.8672],\n", + " [-0.6172, 0.4121, 0.3750, ..., 0.1699, -0.2500, 0.6953]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)" ] }, "execution_count": 12, @@ -771,7 +744,7 @@ " )\n", "\n", " # Feedforward weights\n", - " if \"num_experts\" in param_config:\n", + " if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n", " # Load router (gating) weights\n", " block.ff.gate.weight = assign(\n", " block.ff.gate.weight,\n", @@ -796,10 +769,6 @@ " params[f\"{prefix}.down_proj.weight\"],\n", " f\"{prefix}.down_proj.weight\"\n", " )\n", - " # After assigning weights, move the expert layers from meta to CPU\n", - " block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n", - " block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n", - " block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n", "\n", " else:\n", " block.ff.fc1.weight = assign(\n", @@ -863,12 +832,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "488c832145db4dd4848aa67d54a33f0d", + "model_id": "acf19bb84d754884821e1794cedb25a4", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Fetching 27 files: 0%| | 0/27 [00:00 right:\n", - " return -1\n", - " \n", - " # Calculate middle index\n", - " mid = left + (right - left) // 2\n", - " \n", - " if arr[mid] == target:\n", - " return mid\n", - " elif arr[mid] < target:\n", - " return binary_search_recursive(arr, target, mid + 1, right)\n", - " else:\n", - " return binary_search_recursive(arr, target, left, mid - 1)\n", - "```\n", - "\n", - "## Enhanced Version with Additional Features\n", - "\n", - "```python\n", - "def binary_search_enhanced(arr, target, find_first=True):\n", - " \"\"\"\n", - " Enhanced binary search that can find first or last occurrence\n", - " of a target in case of duplicates\n", - " \n", - " Args:\n", - " arr: Sorted list of elements\n", - " target: Element to search for\n", - " find_first: If True, find" + " while left" ] } ], @@ -1179,7 +1083,7 @@ "for token in generate_text_basic_stream(\n", " model=model,\n", " token_ids=input_token_ids_tensor,\n", - " max_new_tokens=500,\n", + " max_new_tokens=100, # Cut-off after 100 tokens because non-kv variant is very slow\n", " # eos_token_id=tokenizer.eos_token_id\n", "):\n", " token_id = token.squeeze(0).tolist()\n", diff --git a/pkg/llms_from_scratch/kv_cache/qwen3.py b/pkg/llms_from_scratch/kv_cache/qwen3.py index 4960234..9ca06eb 100644 --- a/pkg/llms_from_scratch/kv_cache/qwen3.py +++ b/pkg/llms_from_scratch/kv_cache/qwen3.py @@ -134,14 +134,14 @@ class MoEFeedForward(nn.Module): super().__init__() self.num_experts_per_tok = cfg["num_experts_per_tok"] self.num_experts = cfg["num_experts"] + self.emb_dim = cfg["emb_dim"] self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"]) - meta_device = torch.device("meta") # to reduce memory pressure and only load them when used (trades compute for memory) - self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) - self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) - self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) def forward(self, x): @@ -149,24 +149,37 @@ class MoEFeedForward(nn.Module): topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1) topk_probs = torch.softmax(topk_scores, dim=-1) - expert_outputs = [] - for e in range(self.num_experts): - hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x) - out = self.fc3[e](hidden) - expert_outputs.append(out.unsqueeze(-2)) - expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim) + batch, seq_len, _ = x.shape + x_flat = x.reshape(batch * seq_len, -1) + out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype) - gating_probs = torch.zeros_like(scores) + topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok) + topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok) - for i in range(self.num_experts_per_tok): - indices = topk_indices[..., i:i+1] - prob = topk_probs[..., i:i+1] - gating_probs.scatter_(dim=-1, index=indices, src=prob) - gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1) + unique_experts = torch.unique(topk_indices_flat) - # Weighted sum over experts - y = (gating_probs * expert_outputs).sum(dim=-2) - return y + for expert_id_tensor in unique_experts: + expert_id = int(expert_id_tensor.item()) + mask = topk_indices_flat == expert_id + if not mask.any(): + continue + + token_mask = mask.any(dim=-1) + selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1) + if selected_idx.numel() == 0: + continue + + expert_input = x_flat.index_select(0, selected_idx) + hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input) + expert_out = self.fc3[expert_id](hidden) + + mask_selected = mask[selected_idx] + slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True) + selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1) + + out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1)) + + return out_flat.reshape(batch, seq_len, self.emb_dim) class GroupedQueryAttention(nn.Module): diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index 5a0fc50..214b47a 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -215,14 +215,14 @@ class MoEFeedForward(nn.Module): super().__init__() self.num_experts_per_tok = cfg["num_experts_per_tok"] self.num_experts = cfg["num_experts"] + self.emb_dim = cfg["emb_dim"] self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"]) - meta_device = torch.device("meta") # to reduce memory pressure and only load them when used (trades compute for memory) - self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) - self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) - self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device) + self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"]) for _ in range(cfg["num_experts"])]) def forward(self, x): @@ -230,24 +230,37 @@ class MoEFeedForward(nn.Module): topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1) topk_probs = torch.softmax(topk_scores, dim=-1) - expert_outputs = [] - for e in range(self.num_experts): - hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x) - out = self.fc3[e](hidden) - expert_outputs.append(out.unsqueeze(-2)) - expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim) + batch, seq_len, _ = x.shape + x_flat = x.reshape(batch * seq_len, -1) + out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype) - gating_probs = torch.zeros_like(scores) + topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok) + topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok) - for i in range(self.num_experts_per_tok): - indices = topk_indices[..., i:i+1] - prob = topk_probs[..., i:i+1] - gating_probs.scatter_(dim=-1, index=indices, src=prob) - gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1) + unique_experts = torch.unique(topk_indices_flat) - # Weighted sum over experts - y = (gating_probs * expert_outputs).sum(dim=-2) - return y + for expert_id_tensor in unique_experts: + expert_id = int(expert_id_tensor.item()) + mask = topk_indices_flat == expert_id + if not mask.any(): + continue + + token_mask = mask.any(dim=-1) + selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1) + if selected_idx.numel() == 0: + continue + + expert_input = x_flat.index_select(0, selected_idx) + hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input) + expert_out = self.fc3[expert_id](hidden) + + mask_selected = mask[selected_idx] + slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True) + selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1) + + out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1)) + + return out_flat.reshape(batch, seq_len, self.emb_dim) class GroupedQueryAttention(nn.Module): @@ -500,7 +513,7 @@ def load_weights_into_qwen(model, param_config, params): ) # Feedforward weights - if "num_experts" in param_config: + if param_config.get("num_experts", 0) > 0: # Load router (gating) weights block.ff.gate.weight = assign( block.ff.gate.weight, @@ -525,10 +538,6 @@ def load_weights_into_qwen(model, param_config, params): params[f"{prefix}.down_proj.weight"], f"{prefix}.down_proj.weight" ) - # After assigning weights, move the expert layers from meta to CPU - block.ff.fc1[e] = block.ff.fc1[e].to("cpu") - block.ff.fc2[e] = block.ff.fc2[e].to("cpu") - block.ff.fc3[e] = block.ff.fc3[e].to("cpu") else: block.ff.fc1.weight = assign( diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py index 5f72208..02d9c31 100644 --- a/pkg/llms_from_scratch/tests/test_qwen3.py +++ b/pkg/llms_from_scratch/tests/test_qwen3.py @@ -11,6 +11,7 @@ from llms_from_scratch.qwen3 import ( QWEN_CONFIG_06_B, Qwen3Model, Qwen3Tokenizer, + MoEFeedForward, RMSNorm, ) from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV @@ -113,6 +114,36 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input): "Expected MoEFeedForward in at least one transformer block" +@torch.inference_mode() +def test_moe_forward_matches_reference(dummy_cfg_moe): + torch.manual_seed(0) + moe = MoEFeedForward(dummy_cfg_moe) + x = torch.randn(2, 5, dummy_cfg_moe["emb_dim"]) + + scores = moe.gate(x) + topk_scores, topk_indices = torch.topk(scores, moe.num_experts_per_tok, dim=-1) + topk_probs = torch.softmax(topk_scores, dim=-1) + + expert_outputs = [] + for e in range(moe.num_experts): + hidden = torch.nn.functional.silu(moe.fc1[e](x)) * moe.fc2[e](x) + out = moe.fc3[e](hidden) + expert_outputs.append(out.unsqueeze(-2)) + expert_outputs = torch.cat(expert_outputs, dim=-2) + + gating_probs = torch.zeros_like(scores) + for i in range(moe.num_experts_per_tok): + indices = topk_indices[..., i:i+1] + prob = topk_probs[..., i:i+1] + gating_probs.scatter_(dim=-1, index=indices, src=prob) + gating_probs = gating_probs.unsqueeze(-1) + + expected = (gating_probs * expert_outputs).sum(dim=-2) + + actual = moe(x) + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) + + @torch.inference_mode() @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"]) def test_qwen3_kvcache_equivalence(cfg_name, request):