Improve weight tying handling (#826)

* Improve weight tying handling

* fix
This commit is contained in:
Sebastian Raschka
2025-09-14 15:46:48 -05:00
committed by GitHub
parent 1412b139f2
commit 8add26cbe9
11 changed files with 545 additions and 173 deletions

View File

@@ -427,7 +427,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@@ -842,10 +842,15 @@
"\n",
" def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(\n",
" f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
" )\n",
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
"\n",
" with torch.no_grad():\n",
" if isinstance(right, torch.Tensor):\n",
" left.copy_(right)\n",
" else:\n",
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
"\n",
" return left\n",
"\n",
" # Embedding weights\n",
" if \"model.embed_tokens.weight\" in params:\n",
@@ -948,13 +953,9 @@
" params[\"lm_head.weight\"],\n",
" \"lm_head.weight\",\n",
" )\n",
" elif \"model.embed_tokens.weight\" in params:\n",
" # Weight tying: reuse the embedding weights\n",
" model.out_head.weight = assign(\n",
" model.out_head.weight,\n",
" params[\"model.embed_tokens.weight\"],\n",
" \"model.embed_tokens.weight\",\n",
" )"
" else:\n",
" model.out_head.weight = model.tok_emb.weight\n",
" print(\"Model uses weight tying.\")"
]
},
{
@@ -1036,7 +1037,15 @@
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uses weight tying.\n"
]
}
],
"source": [
"import json\n",
"import os\n",
@@ -1175,7 +1184,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"id": "988f55e2-0f60-4bd8-ae55-db116ff2b26d",
"metadata": {},
"outputs": [],
@@ -1186,7 +1195,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1213,7 +1222,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"id": "56c9d0cf-25e9-4375-8d5c-368fa6911fdf",
"metadata": {},
"outputs": [

View File

@@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 7,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
@@ -320,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 8,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
@@ -386,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@@ -507,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 10,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
@@ -554,7 +554,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 11,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@@ -567,7 +567,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 12,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
@@ -602,7 +602,7 @@
")"
]
},
"execution_count": 19,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -621,7 +621,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 13,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
@@ -634,7 +634,7 @@
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 20,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -645,7 +645,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
@@ -676,7 +676,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
@@ -726,7 +726,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 16,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@@ -756,7 +756,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 17,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
@@ -767,10 +767,15 @@
"\n",
" def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(\n",
" f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
" )\n",
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
"\n",
" with torch.no_grad():\n",
" if isinstance(right, torch.Tensor):\n",
" left.copy_(right)\n",
" else:\n",
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
"\n",
" return left\n",
"\n",
" # Embedding weights\n",
" if \"model.embed_tokens.weight\" in params:\n",
@@ -873,13 +878,9 @@
" params[\"lm_head.weight\"],\n",
" \"lm_head.weight\",\n",
" )\n",
" elif \"model.embed_tokens.weight\" in params:\n",
" # Weight tying: reuse the embedding weights\n",
" model.out_head.weight = assign(\n",
" model.out_head.weight,\n",
" params[\"model.embed_tokens.weight\"],\n",
" \"model.embed_tokens.weight\",\n",
" )"
" else:\n",
" model.out_head.weight = model.tok_emb.weight\n",
" print(\"Model uses weight tying.\")"
]
},
{
@@ -900,7 +901,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 18,
"id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8",
"metadata": {},
"outputs": [],
@@ -913,7 +914,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 19,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
@@ -936,7 +937,15 @@
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uses weight tying.\n"
]
}
],
"source": [
"import json\n",
"import os\n",
@@ -989,7 +998,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 20,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
@@ -1019,7 +1028,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 21,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
@@ -1037,7 +1046,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 22,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
@@ -1047,7 +1056,7 @@
"'<bos><start_of_turn>user\\nGive me a short introduction to large language models.<end_of_turn>\\n<start_of_turn>model\\n'"
]
},
"execution_count": 29,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -1075,7 +1084,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"id": "a6250333-9cf0-4f36-8e28-76be2eac1c43",
"metadata": {},
"outputs": [],
@@ -1086,7 +1095,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 24,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1112,7 +1121,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 25,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1176,7 +1185,7 @@
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -1190,7 +1199,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.10.16"
}
},
"nbformat": 4,