mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Optional weight tying for Qwen3 and Llama3.2 pretraining (#949)
* optional weight tying for Qwen3 and Llama3.2 * typo
This commit is contained in:
@@ -358,7 +358,7 @@
|
||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Reusable utilities\n",
|
||||
" cos, sin = compute_rope_params(\n",
|
||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||
" theta_base=cfg[\"rope_base\"],\n",
|
||||
|
||||
@@ -432,7 +432,7 @@
|
||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Reusable utilities\n",
|
||||
" if cfg[\"head_dim\"] is None:\n",
|
||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||
" else:\n",
|
||||
|
||||
@@ -422,7 +422,7 @@
|
||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Reusable utilities\n",
|
||||
" if cfg[\"head_dim\"] is None:\n",
|
||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||
" else:\n",
|
||||
|
||||
@@ -388,7 +388,7 @@
|
||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Reusable utilities\n",
|
||||
" if cfg[\"head_dim\"] is None:\n",
|
||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||
" else:\n",
|
||||
|
||||
@@ -113,7 +113,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -329,7 +329,11 @@
|
||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Uncomment the following code to tie weights\n",
|
||||
" # self.out_head.weight = self.tok_emb.weight\n",
|
||||
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
|
||||
"\n",
|
||||
" # Reusable utilities\n",
|
||||
" cos, sin = compute_rope_params(\n",
|
||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||
" theta_base=cfg[\"rope_base\"],\n",
|
||||
|
||||
@@ -121,7 +121,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -332,7 +332,11 @@
|
||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" # Uncomment the following code to tie weights\n",
|
||||
" # self.out_head.weight = self.tok_emb.weight\n",
|
||||
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
|
||||
"\n",
|
||||
" # Reusable utilities\n",
|
||||
" if cfg[\"head_dim\"] is None:\n",
|
||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||
" else:\n",
|
||||
|
||||
Reference in New Issue
Block a user