Improve MoE implementation (#841)

This commit is contained in:
Sebastian Raschka
2025-09-22 15:21:06 -05:00
committed by GitHub
parent 20041fb94b
commit e742d8af2c
6 changed files with 177 additions and 250 deletions

View File

@@ -150,79 +150,52 @@
" super().__init__()\n",
" self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n",
" self.num_experts = cfg[\"num_experts\"]\n",
" self.emb_dim = cfg[\"emb_dim\"]\n",
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # meta device to reduce memory pressure when initializing the model before loading weights\n",
" meta_device = torch.device(\"meta\")\n",
" self.fc1 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc2 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
" )\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc3 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
" )\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
"\n",
" def forward(self, x):\n",
" b, seq_len, embed_dim = x.shape\n",
" scores = self.gate(x) # (b, seq_len, num_experts)\n",
" topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
" topk_probs = torch.softmax(topk_scores, dim=-1)\n",
" \n",
" expert_outputs = []\n",
" for e in range(self.num_experts):\n",
" hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n",
" out = self.fc3[e](hidden)\n",
" expert_outputs.append(out.unsqueeze(-2))\n",
" expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n",
"\n",
" gating_probs = torch.zeros_like(scores)\n",
" batch, seq_len, _ = x.shape\n",
" x_flat = x.reshape(batch * seq_len, -1)\n",
" out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n",
"\n",
" for i in range(self.num_experts_per_tok):\n",
" indices = topk_indices[..., i:i+1]\n",
" prob = topk_probs[..., i:i+1]\n",
" gating_probs.scatter_(dim=-1, index=indices, src=prob)\n",
" gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n",
" \n",
" # Weighted sum over experts\n",
" y = (gating_probs * expert_outputs).sum(dim=-2)\n",
" return y\n",
" topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n",
" topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n",
"\n",
" unique_experts = torch.unique(topk_indices_flat)\n",
"\n",
" # For some reason, the version below is slower than the naive version\n",
" # above that computes all experts, even the unused ones\n",
" for expert_id_tensor in unique_experts:\n",
" expert_id = int(expert_id_tensor.item())\n",
" mask = topk_indices_flat == expert_id\n",
" if not mask.any():\n",
" continue\n",
"\n",
" # def forward(self, x):\n",
" # scores = self.gate(x) # (b, seq_len, num_experts)\n",
" # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
" # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
" # y = torch.zeros_like(x)\n",
" #\n",
" # for i in range(self.num_experts_per_tok):\n",
" # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
" # expert_indices = topk_indices[..., i]\n",
" # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
" #\n",
" # # For each expert, process only the tokens assigned to it\n",
" # for e in range(self.num_experts):\n",
" # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
" # if mask.any():\n",
" # selected = x[mask] # (num_tokens_e, emb_dim)\n",
" # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
" # y[mask] += prob[mask] * out\n",
" # return y"
" token_mask = mask.any(dim=-1)\n",
" selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n",
" if selected_idx.numel() == 0:\n",
" continue\n",
"\n",
" expert_input = x_flat.index_select(0, selected_idx)\n",
" hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n",
" expert_out = self.fc3[expert_id](hidden)\n",
"\n",
" mask_selected = mask[selected_idx]\n",
" slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n",
" selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n",
"\n",
" out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n",
"\n",
" return out_flat.reshape(batch, seq_len, self.emb_dim)"
]
},
{
@@ -829,7 +802,7 @@
" )\n",
"\n",
" # Feedforward weights\n",
" if \"num_experts\" in param_config:\n",
" if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n",
" # Load router (gating) weights\n",
" block.ff.gate.weight = assign(\n",
" block.ff.gate.weight,\n",
@@ -854,10 +827,6 @@
" params[f\"{prefix}.down_proj.weight\"],\n",
" f\"{prefix}.down_proj.weight\"\n",
" )\n",
" # After assigning weights, move the expert layers from meta to CPU\n",
" block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n",
" block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n",
" block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n",
"\n",
" else:\n",
" block.ff.fc1.weight = assign(\n",

View File

@@ -89,8 +89,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.34.3\n",
"tokenizers version: 0.21.4\n",
"huggingface_hub version: 0.35.0\n",
"tokenizers version: 0.22.1\n",
"torch version: 2.7.1+cu128\n"
]
}
@@ -150,79 +150,52 @@
" super().__init__()\n",
" self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n",
" self.num_experts = cfg[\"num_experts\"]\n",
" self.emb_dim = cfg[\"emb_dim\"]\n",
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # meta device to reduce memory pressure when initializing the model before loading weights\n",
" meta_device = torch.device(\"meta\")\n",
" self.fc1 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc2 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
" )\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc3 = nn.ModuleList([\n",
" nn.Linear(\n",
" cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
" bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
" )\n",
" for _ in range(cfg[\"num_experts\"])]\n",
" )\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
"\n",
" def forward(self, x):\n",
" b, seq_len, embed_dim = x.shape\n",
" scores = self.gate(x) # (b, seq_len, num_experts)\n",
" topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
" topk_probs = torch.softmax(topk_scores, dim=-1)\n",
" \n",
" expert_outputs = []\n",
" for e in range(self.num_experts):\n",
" hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n",
" out = self.fc3[e](hidden)\n",
" expert_outputs.append(out.unsqueeze(-2))\n",
" expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n",
"\n",
" gating_probs = torch.zeros_like(scores)\n",
" batch, seq_len, _ = x.shape\n",
" x_flat = x.reshape(batch * seq_len, -1)\n",
" out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)\n",
"\n",
" for i in range(self.num_experts_per_tok):\n",
" indices = topk_indices[..., i:i+1]\n",
" prob = topk_probs[..., i:i+1]\n",
" gating_probs.scatter_(dim=-1, index=indices, src=prob)\n",
" gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n",
" \n",
" # Weighted sum over experts\n",
" y = (gating_probs * expert_outputs).sum(dim=-2)\n",
" return y\n",
" topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)\n",
" topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)\n",
"\n",
" unique_experts = torch.unique(topk_indices_flat)\n",
"\n",
" # For some reason, the version below is slower than the naive version\n",
" # above that computes all experts, even the unused ones\n",
" for expert_id_tensor in unique_experts:\n",
" expert_id = int(expert_id_tensor.item())\n",
" mask = topk_indices_flat == expert_id\n",
" if not mask.any():\n",
" continue\n",
"\n",
" # def forward(self, x):\n",
" # scores = self.gate(x) # (b, seq_len, num_experts)\n",
" # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
" # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
" # y = torch.zeros_like(x)\n",
" #\n",
" # for i in range(self.num_experts_per_tok):\n",
" # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
" # expert_indices = topk_indices[..., i]\n",
" # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
" #\n",
" # # For each expert, process only the tokens assigned to it\n",
" # for e in range(self.num_experts):\n",
" # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
" # if mask.any():\n",
" # selected = x[mask] # (num_tokens_e, emb_dim)\n",
" # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
" # y[mask] += prob[mask] * out\n",
" # return y"
" token_mask = mask.any(dim=-1)\n",
" selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)\n",
" if selected_idx.numel() == 0:\n",
" continue\n",
"\n",
" expert_input = x_flat.index_select(0, selected_idx)\n",
" hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)\n",
" expert_out = self.fc3[expert_id](hidden)\n",
"\n",
" mask_selected = mask[selected_idx]\n",
" slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)\n",
" selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)\n",
"\n",
" out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))\n",
"\n",
" return out_flat.reshape(batch, seq_len, self.emb_dim)"
]
},
{
@@ -582,10 +555,10 @@
{
"data": {
"text/plain": [
"tensor([[[nan, nan, nan, ..., nan, nan, nan],\n",
" [nan, nan, nan, ..., nan, nan, nan],\n",
" [nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
"tensor([[[ 0.3223, -0.0562, 0.2490, ..., 0.4551, -0.0542, 0.8242],\n",
" [ 0.0688, 0.0786, -0.0312, ..., 0.6406, -0.9141, 0.8672],\n",
" [-0.6172, 0.4121, 0.3750, ..., 0.1699, -0.2500, 0.6953]]],\n",
" device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 12,
@@ -771,7 +744,7 @@
" )\n",
"\n",
" # Feedforward weights\n",
" if \"num_experts\" in param_config:\n",
" if \"num_experts\" in param_config and param_config[\"num_experts\"] > 0:\n",
" # Load router (gating) weights\n",
" block.ff.gate.weight = assign(\n",
" block.ff.gate.weight,\n",
@@ -796,10 +769,6 @@
" params[f\"{prefix}.down_proj.weight\"],\n",
" f\"{prefix}.down_proj.weight\"\n",
" )\n",
" # After assigning weights, move the expert layers from meta to CPU\n",
" block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n",
" block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n",
" block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n",
"\n",
" else:\n",
" block.ff.fc1.weight = assign(\n",
@@ -863,12 +832,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "488c832145db4dd4848aa67d54a33f0d",
"model_id": "acf19bb84d754884821e1794cedb25a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 27 files: 0%| | 0/27 [00:00<?, ?it/s]"
"Fetching 28 files: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
@@ -1099,76 +1068,11 @@
" \n",
" Returns:\n",
" int: Index of target if found, -1 if not found\n",
" \n",
" Time Complexity: O(log n)\n",
" Space Complexity: O(1)\n",
" \"\"\"\n",
" left = 0\n",
" right = len(arr) - 1\n",
" \n",
" while left <= right:\n",
" # Calculate middle index (avoiding potential overflow)\n",
" mid = left + (right - left) // 2\n",
" \n",
" if arr[mid] == target:\n",
" return mid\n",
" elif arr[mid] < target:\n",
" left = mid + 1\n",
" else:\n",
" right = mid - 1\n",
" \n",
" return -1 # Target not found\n",
"```\n",
"\n",
"## Recursive Binary Search\n",
"\n",
"```python\n",
"def binary_search_recursive(arr, target, left=0, right=None):\n",
" \"\"\"\n",
" Recursive binary search implementation\n",
" \n",
" Args:\n",
" arr: Sorted list of elements\n",
" target: Element to search for\n",
" left: Left boundary (default: 0)\n",
" right: Right boundary (default: len(arr) - 1)\n",
" \n",
" Returns:\n",
" int: Index of target if found, -1 if not found\n",
" \n",
" Time Complexity: O(log n)\n",
" Space Complexity: O(log n) due to recursion stack\n",
" \"\"\"\n",
" if right is None:\n",
" right = len(arr) - 1\n",
" \n",
" # Base case: element not found\n",
" if left > right:\n",
" return -1\n",
" \n",
" # Calculate middle index\n",
" mid = left + (right - left) // 2\n",
" \n",
" if arr[mid] == target:\n",
" return mid\n",
" elif arr[mid] < target:\n",
" return binary_search_recursive(arr, target, mid + 1, right)\n",
" else:\n",
" return binary_search_recursive(arr, target, left, mid - 1)\n",
"```\n",
"\n",
"## Enhanced Version with Additional Features\n",
"\n",
"```python\n",
"def binary_search_enhanced(arr, target, find_first=True):\n",
" \"\"\"\n",
" Enhanced binary search that can find first or last occurrence\n",
" of a target in case of duplicates\n",
" \n",
" Args:\n",
" arr: Sorted list of elements\n",
" target: Element to search for\n",
" find_first: If True, find"
" while left"
]
}
],
@@ -1179,7 +1083,7 @@
"for token in generate_text_basic_stream(\n",
" model=model,\n",
" token_ids=input_token_ids_tensor,\n",
" max_new_tokens=500,\n",
" max_new_tokens=100, # Cut-off after 100 tokens because non-kv variant is very slow\n",
" # eos_token_id=tokenizer.eos_token_id\n",
"):\n",
" token_id = token.squeeze(0).tolist()\n",