From 9c4be478f89f95414adac173ec034dd777e80974 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:07:04 +0100 Subject: [PATCH] Optional weight tying for Qwen3 and Llama3.2 pretraining (#949) * optional weight tying for Qwen3 and Llama3.2 * typo --- ch05/07_gpt_to_llama/standalone-llama32.ipynb | 2 +- ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb | 2 +- ch05/11_qwen3/standalone-qwen3-moe.ipynb | 2 +- ch05/11_qwen3/standalone-qwen3.ipynb | 2 +- ch05/14_ch05_with_other_llms/ch05-llama32.ipynb | 8 ++++++-- ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb | 8 ++++++-- pkg/llms_from_scratch/kv_cache/llama3.py | 2 +- 7 files changed, 17 insertions(+), 9 deletions(-) diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index db13d63..ab68fff 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -358,7 +358,7 @@ " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Reusable utilities\n", " cos, sin = compute_rope_params(\n", " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n", " theta_base=cfg[\"rope_base\"],\n", diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 871a085..c01f5c1 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -432,7 +432,7 @@ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Reusable utilities\n", " if cfg[\"head_dim\"] is None:\n", " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n", " else:\n", diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb index e9a8f22..de10c1a 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb @@ -422,7 +422,7 @@ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Reusable utilities\n", " if cfg[\"head_dim\"] is None:\n", " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n", " else:\n", diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb index 156854f..1e22787 100644 --- a/ch05/11_qwen3/standalone-qwen3.ipynb +++ b/ch05/11_qwen3/standalone-qwen3.ipynb @@ -388,7 +388,7 @@ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Reusable utilities\n", " if cfg[\"head_dim\"] is None:\n", " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n", " else:\n", diff --git a/ch05/14_ch05_with_other_llms/ch05-llama32.ipynb b/ch05/14_ch05_with_other_llms/ch05-llama32.ipynb index 4f9dac5..65ae73c 100644 --- a/ch05/14_ch05_with_other_llms/ch05-llama32.ipynb +++ b/ch05/14_ch05_with_other_llms/ch05-llama32.ipynb @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "86000d74-624a-48f0-86da-f41926cb9e04", "metadata": { "colab": { @@ -329,7 +329,11 @@ " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Uncomment the following code to tie weights\n", + " # self.out_head.weight = self.tok_emb.weight\n", + " # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n", + "\n", + " # Reusable utilities\n", " cos, sin = compute_rope_params(\n", " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n", " theta_base=cfg[\"rope_base\"],\n", diff --git a/ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb b/ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb index 1218aee..ea09408 100644 --- a/ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb +++ b/ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "86000d74-624a-48f0-86da-f41926cb9e04", "metadata": { "colab": { @@ -332,7 +332,11 @@ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", - " # Reusuable utilities\n", + " # Uncomment the following code to tie weights\n", + " # self.out_head.weight = self.tok_emb.weight\n", + " # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n", + "\n", + " # Reusable utilities\n", " if cfg[\"head_dim\"] is None:\n", " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n", " else:\n", diff --git a/pkg/llms_from_scratch/kv_cache/llama3.py b/pkg/llms_from_scratch/kv_cache/llama3.py index 098dee4..dc4adfe 100644 --- a/pkg/llms_from_scratch/kv_cache/llama3.py +++ b/pkg/llms_from_scratch/kv_cache/llama3.py @@ -65,7 +65,7 @@ class Llama3Model(nn.Module): self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) - # Reusuable utilities + # Reusable utilities cos, sin = compute_rope_params( head_dim=cfg["emb_dim"] // cfg["n_heads"], theta_base=cfg["rope_base"],