Optional weight tying for Qwen3 and Llama3.2 pretraining (#949)

* optional weight tying for Qwen3 and Llama3.2

* typo

This commit is contained in:
casinca
2026-01-14 16:07:04 +01:00
committed by GitHub
parent e0dbec3331
commit 9c4be478f8
7 changed files with 17 additions and 9 deletions
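
Weight tying points the output projection (out_head) at the token-embedding matrix (tok_emb) so both layers share a single vocab_size × emb_dim parameter tensor, cutting the model's parameter count accordingly. This commit adds it as an opt-in, commented-out snippet to the Qwen3 and Llama3.2 pretraining notebooks, alongside a "Reusuable" → "Reusable" typo fix in the other notebooks.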

View File

@@ -358,7 +358,7 @@
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",

View File

@@ -432,7 +432,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -422,7 +422,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -388,7 +388,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -113,7 +113,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "86000d74-624a-48f0-86da-f41926cb9e04",
 "metadata": {
 "colab": {
@@ -329,7 +329,11 @@
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Uncomment the following code to tie weights\n",
" # self.out_head.weight = self.tok_emb.weight\n",
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
"\n",
" # Reusable utilities\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",

View File

@@ -121,7 +121,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "86000d74-624a-48f0-86da-f41926cb9e04",
 "metadata": {
 "colab": {
@@ -332,7 +332,11 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Uncomment the following code to tie weights\n",
" # self.out_head.weight = self.tok_emb.weight\n",
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
"\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",