Mixture-of-Experts intro (#888)

Sebastian Raschka
2025-10-19 22:17:59 -05:00
committed by GitHub
parent 27b6dfab9e
commit 218221ab62
13 changed files with 1333 additions and 228 deletions

View File

@@ -153,11 +153,11 @@
" self.emb_dim = cfg[\"emb_dim\"]\n",
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
"\n",
" def forward(self, x):\n",
@@ -550,7 +550,7 @@
" \"dtype\": torch.bfloat16,\n",
" \"num_experts\": 128,\n",
" \"num_experts_per_tok\": 8,\n",
" \"moe_intermediate_size\": 768,\n",
" \"moe_hidden_dim\": 768,\n",
"}"
]
},
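The renamed moe_hidden_dim, together with num_experts and num_experts_per_tok above, determines how many expert parameters exist in total versus how many are active per token. A quick back-of-the-envelope check (emb_dim is not part of this hunk, so the value below is only an illustrative assumption):

```python
# Back-of-the-envelope MoE sizing from the config above.
# emb_dim is not shown in this hunk; 2048 is purely an illustrative assumption.
emb_dim = 2048
moe_hidden_dim = 768
num_experts = 128
num_experts_per_tok = 8

# each expert holds three bias-free Linear layers:
# fc1, fc2: emb_dim -> moe_hidden_dim, and fc3: moe_hidden_dim -> emb_dim
params_per_expert = 3 * emb_dim * moe_hidden_dim
total_expert_params = num_experts * params_per_expert
active_expert_params = num_experts_per_tok * params_per_expert

print(f"per expert:       {params_per_expert:,}")
print(f"all experts:      {total_expert_params:,}")    # held in memory
print(f"active per token: {active_expert_params:,}")   # used in each forward pass
```

With these numbers, only num_experts_per_tok / num_experts = 8/128 = 1/16 of the expert parameters participate in any single token's forward pass, which is the main point of the sparse MoE design.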
@@ -1223,7 +1223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -153,11 +153,11 @@
" self.emb_dim = cfg[\"emb_dim\"]\n",
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
" for _ in range(cfg[\"num_experts\"])])\n",
"\n",
" def forward(self, x):\n",
@@ -492,7 +492,7 @@
" \"dtype\": torch.bfloat16,\n",
" \"num_experts\": 128,\n",
" \"num_experts_per_tok\": 8,\n",
" \"moe_intermediate_size\": 768,\n",
" \"moe_hidden_dim\": 768,\n",
"}"
]
},
@@ -1140,7 +1140,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,