mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Mixture-of-Experts intro (#888)
This commit is contained in:
committed by
GitHub
parent
27b6dfab9e
commit
218221ab62
@@ -153,11 +153,11 @@
|
||||
" self.emb_dim = cfg[\"emb_dim\"]\n",
|
||||
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
@@ -550,7 +550,7 @@
|
||||
" \"dtype\": torch.bfloat16,\n",
|
||||
" \"num_experts\": 128,\n",
|
||||
" \"num_experts_per_tok\": 8,\n",
|
||||
" \"moe_intermediate_size\": 768,\n",
|
||||
" \"moe_hidden_dim\": 768,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -1223,7 +1223,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -153,11 +153,11 @@
|
||||
" self.emb_dim = cfg[\"emb_dim\"]\n",
|
||||
" self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_hidden_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_hidden_dim\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
" for _ in range(cfg[\"num_experts\"])])\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
@@ -492,7 +492,7 @@
|
||||
" \"dtype\": torch.bfloat16,\n",
|
||||
" \"num_experts\": 128,\n",
|
||||
" \"num_experts_per_tok\": 8,\n",
|
||||
" \"moe_intermediate_size\": 768,\n",
|
||||
" \"moe_hidden_dim\": 768,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -1140,7 +1140,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user