mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Improve weight tying handling (#826)
* Improve weight tying handling * fix
This commit is contained in:
committed by
GitHub
parent
1412b139f2
commit
8add26cbe9
@@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 7,
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
||||
"metadata": {
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
||||
@@ -320,7 +320,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 8,
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
||||
"metadata": {
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
||||
@@ -386,7 +386,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
||||
"metadata": {
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
||||
@@ -507,7 +507,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 10,
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
||||
"metadata": {
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
||||
@@ -554,7 +554,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 11,
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
||||
"metadata": {
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
||||
@@ -567,7 +567,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 12,
|
||||
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -602,7 +602,7 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -621,7 +621,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -634,7 +634,7 @@
|
||||
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -645,7 +645,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 14,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -676,7 +676,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 15,
|
||||
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -726,7 +726,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 16,
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
|
||||
"metadata": {
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
|
||||
@@ -756,7 +756,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 17,
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e",
|
||||
"metadata": {
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e"
|
||||
@@ -767,10 +767,15 @@
|
||||
"\n",
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(\n",
|
||||
" f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
|
||||
" )\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" return left\n",
|
||||
"\n",
|
||||
" # Embedding weights\n",
|
||||
" if \"model.embed_tokens.weight\" in params:\n",
|
||||
@@ -873,13 +878,9 @@
|
||||
" params[\"lm_head.weight\"],\n",
|
||||
" \"lm_head.weight\",\n",
|
||||
" )\n",
|
||||
" elif \"model.embed_tokens.weight\" in params:\n",
|
||||
" # Weight tying: reuse the embedding weights\n",
|
||||
" model.out_head.weight = assign(\n",
|
||||
" model.out_head.weight,\n",
|
||||
" params[\"model.embed_tokens.weight\"],\n",
|
||||
" \"model.embed_tokens.weight\",\n",
|
||||
" )"
|
||||
" else:\n",
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -900,7 +901,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 18,
|
||||
"id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -913,7 +914,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 19,
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -936,7 +937,15 @@
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model uses weight tying.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
@@ -989,7 +998,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 20,
|
||||
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1019,7 +1028,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 21,
|
||||
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1037,7 +1046,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 22,
|
||||
"id": "1946b534-e3af-431a-a222-391a60bfa892",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1047,7 +1056,7 @@
|
||||
"'<bos><start_of_turn>user\\nGive me a short introduction to large language models.<end_of_turn>\\n<start_of_turn>model\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -1075,7 +1084,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 23,
|
||||
"id": "a6250333-9cf0-4f36-8e28-76be2eac1c43",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1086,7 +1095,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 24,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1112,7 +1121,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 25,
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||
"metadata": {
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
|
||||
@@ -1176,7 +1185,7 @@
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -1190,7 +1199,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.6"
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user