From 8add26cbe9b810b02ce0b99c3dd848f34e6026f4 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sun, 14 Sep 2025 15:46:48 -0500 Subject: [PATCH] Improve weight tying handling (#826) * Improve weight tying handling * fix --- .../converting-gpt-to-llama2.ipynb | 15 +- .../converting-llama2-to-llama3.ipynb | 337 ++++++++++++++++-- ch05/07_gpt_to_llama/standalone-llama32.ipynb | 160 +++++---- .../standalone-qwen3-moe-plus-kvcache.ipynb | 14 +- ch05/11_qwen3/standalone-qwen3-moe.ipynb | 14 +- .../standalone-qwen3-plus-kvcache.ipynb | 14 +- ch05/11_qwen3/standalone-qwen3.ipynb | 14 +- .../standalone-gemma3-plus-kvcache.ipynb | 41 ++- ch05/12_gemma3/standalone-gemma3.ipynb | 81 +++-- pkg/llms_from_scratch/llama3.py | 14 +- pkg/llms_from_scratch/qwen3.py | 14 +- 11 files changed, 545 insertions(+), 173 deletions(-) diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index fa5af69..0e8bc04 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -1400,14 +1400,17 @@ }, "outputs": [], "source": [ - "def assign(left, right):\n", + "def assign(left, right, tensor_name=\"unknown\"):\n", " if left.shape != right.shape:\n", - " raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n", + " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n", + " \n", + " with torch.no_grad():\n", + " if isinstance(right, torch.Tensor):\n", + " left.copy_(right)\n", + " else:\n", + " left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n", "\n", - " if isinstance(right, torch.Tensor):\n", - " return torch.nn.Parameter(right.clone().detach())\n", - " else:\n", - " return torch.nn.Parameter(torch.tensor(right))\n", + " return left \n", "\n", "\n", "def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n", diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 8ee68bc..bdd1065 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -94,10 +94,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "blobfile version: 3.0.0\n", - "huggingface_hub version: 0.30.1\n", - "tiktoken version: 0.9.0\n", - "torch version: 2.6.0\n" + "blobfile version: 3.1.0\n", + "huggingface_hub version: 0.34.4\n", + "tiktoken version: 0.11.0\n", + "torch version: 2.8.0\n" ] } ], @@ -1184,16 +1184,7 @@ "id": "3357a230-b678-4691-a238-257ee4e80185", "outputId": "a3652def-ea7f-46fb-f293-2a59affb71a0" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from huggingface_hub import login\n", "import json\n", @@ -1226,7 +1217,22 @@ "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", "outputId": "c9836ba8-5176-4dd5-b618-6cc36fdbe1f0" }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "685326b4fd014ff689e928f4200f5182", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "original/tokenizer.model: 0%| | 0.00/2.18M [00:00)" ] }, - "execution_count": 20, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -645,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 14, "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": { "colab": { @@ -676,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 15, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": { "colab": { @@ -726,7 +726,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 16, "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" @@ -756,7 +756,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 17, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": { "id": "75166128-5899-4995-9b88-9672e135650e" @@ -767,10 +767,15 @@ "\n", " def assign(left, right, tensor_name=\"unknown\"):\n", " if left.shape != right.shape:\n", - " raise ValueError(\n", - " f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n", - " )\n", - " return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n", + " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n", + "\n", + " with torch.no_grad():\n", + " if isinstance(right, torch.Tensor):\n", + " left.copy_(right)\n", + " else:\n", + " left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n", + "\n", + " return left\n", "\n", " # Embedding weights\n", " if \"model.embed_tokens.weight\" in params:\n", @@ -873,13 +878,9 @@ " params[\"lm_head.weight\"],\n", " \"lm_head.weight\",\n", " )\n", - " elif \"model.embed_tokens.weight\" in params:\n", - " # Weight tying: reuse the embedding weights\n", - " model.out_head.weight = assign(\n", - " model.out_head.weight,\n", - " params[\"model.embed_tokens.weight\"],\n", - " \"model.embed_tokens.weight\",\n", - " )" + " else:\n", + " model.out_head.weight = model.tok_emb.weight\n", + " print(\"Model uses weight tying.\")" ] }, { @@ -900,7 +901,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 18, "id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8", "metadata": {}, "outputs": [], @@ -913,7 +914,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 19, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "metadata": { "colab": { @@ -936,7 +937,15 @@ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model uses weight tying.\n" + ] + } + ], "source": [ "import json\n", "import os\n", @@ -989,7 +998,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 20, "id": "b68ab489-48e5-471e-a814-56cda2d60f81", "metadata": {}, "outputs": [], @@ -1019,7 +1028,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 21, "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", "metadata": {}, "outputs": [], @@ -1037,7 +1046,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 22, "id": "1946b534-e3af-431a-a222-391a60bfa892", "metadata": {}, "outputs": [ @@ -1047,7 +1056,7 @@ "'user\\nGive me a short introduction to large language models.\\nmodel\\n'" ] }, - "execution_count": 29, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1075,7 +1084,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "a6250333-9cf0-4f36-8e28-76be2eac1c43", "metadata": {}, "outputs": [], @@ -1086,7 +1095,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 24, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1112,7 +1121,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 25, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1176,7 +1185,7 @@ "provenance": [] }, "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -1190,7 +1199,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py index 0580ee2..0bfb469 100644 --- a/pkg/llms_from_scratch/llama3.py +++ b/pkg/llms_from_scratch/llama3.py @@ -503,13 +503,17 @@ def assign(left, right, tensor_name="unknown"): if left.shape != right.shape: raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}") - if isinstance(right, torch.Tensor): - return torch.nn.Parameter(right.clone().detach()) - else: - return torch.nn.Parameter(torch.tensor(right)) + with torch.no_grad(): + if isinstance(right, torch.Tensor): + left.copy_(right) + else: + left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device)) + + return left def load_weights_into_llama(model, param_config, params): + model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") for l in range(param_config["n_layers"]): @@ -569,5 +573,5 @@ def load_weights_into_llama(model, param_config, params): if "lm_head.weight" in params.keys(): model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight") else: - model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") + model.out_head.weight = model.tok_emb.weight print("Model uses weight tying.") diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index 09d4e0a..a68b324 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -64,7 +64,7 @@ QWEN3_CONFIG_8B = { "context_length": 40_960, "emb_dim": 4096, # 60% larger than above "n_heads": 32, - "n_layers": 36, + "n_layers": 36, "hidden_dim": 12288, # 26% larger than above "head_dim": 128, "qk_norm": True, @@ -387,7 +387,14 @@ def load_weights_into_qwen(model, param_config, params): def assign(left, right, tensor_name="unknown"): if left.shape != right.shape: raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}") - return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right)) + + with torch.no_grad(): + if isinstance(right, torch.Tensor): + left.copy_(right) + else: + left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device)) + + return left model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") @@ -500,9 +507,8 @@ def load_weights_into_qwen(model, param_config, params): if "lm_head.weight" in params: model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight") else: - # Model uses weight tying, hence we reuse the embedding layer weights here + model.out_head.weight = model.tok_emb.weight print("Model uses weight tying.") - model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") class Qwen3Tokenizer: