Improve weight tying handling (#826)

* Improve weight tying handling

* fix
This commit is contained in:
Sebastian Raschka
2025-09-14 15:46:48 -05:00
committed by GitHub
parent 1412b139f2
commit 8add26cbe9
11 changed files with 545 additions and 173 deletions

View File

@@ -80,10 +80,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"blobfile version: 3.0.0\n",
"huggingface_hub version: 0.30.1\n",
"tiktoken version: 0.9.0\n",
"torch version: 2.6.0\n"
"blobfile version: 3.1.0\n",
"huggingface_hub version: 0.34.4\n",
"tiktoken version: 0.11.0\n",
"torch version: 2.8.0\n"
]
}
],
@@ -470,50 +470,9 @@
"model = Llama3Model(LLAMA32_CONFIG)"
]
},
{
"cell_type": "markdown",
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
"metadata": {
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1"
},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 1,498,482,688\n",
"\n",
"Total number of unique parameters: 1,235,814,400\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
@@ -561,6 +520,31 @@
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "41176fb0-d58a-443a-912f-4f436564b5f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 1,498,482,688\n",
"\n",
"Total number of unique parameters: 1,235,814,400\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying: the token embedding and output head share one matrix, so subtract the duplicate copy\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
@@ -718,16 +702,7 @@
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"# Uncomment and run the following code if you are executing the notebook for the first time\n",
"\n",
@@ -807,11 +782,14 @@
"def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
" \n",
" with torch.no_grad():\n",
" if isinstance(right, torch.Tensor):\n",
" left.copy_(right)\n",
" else:\n",
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
"\n",
" if isinstance(right, torch.Tensor):\n",
" return torch.nn.Parameter(right.clone().detach())\n",
" else:\n",
" return torch.nn.Parameter(torch.tensor(right))\n",
" return left \n",
"\n",
"\n",
"def load_weights_into_llama(model, param_config, params):\n",
@@ -874,7 +852,7 @@
" if \"lm_head.weight\" in params.keys():\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
" else:\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
" model.out_head.weight = model.tok_emb.weight\n",
" print(\"Model uses weight tying.\")"
]
},
@@ -945,6 +923,42 @@
{
"cell_type": "code",
"execution_count": 19,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of unique parameters: 1,235,814,400\n"
]
}
],
"source": [
"def count_unique_parameters(model):\n",
" unique_params = set()\n",
" total_unique_params = 0\n",
" \n",
" for param in model.parameters():\n",
" if param.data_ptr() not in unique_params:\n",
" total_unique_params += param.numel()\n",
" unique_params.add(param.data_ptr())\n",
" \n",
" return total_unique_params\n",
"\n",
"total_params_uniq = count_unique_parameters(model)\n",
"print(f\"Total number of unique parameters: {total_params_uniq:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": {
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
@@ -959,9 +973,29 @@
}
],
"source": [
"# Check that both weight tensors hold identical values (equal values alone do not prove shared memory)\n",
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1ec6977d-ec42-42b5-bca2-3ecda791ea66",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weight tying: True\n"
]
}
],
"source": [
"# Furthermore, check that both weights share the same underlying memory (i.e., they are tied, not copies)\n",
"print(\"Weight tying:\", model.tok_emb.weight.data_ptr() == model.out_head.weight.data_ptr())"
]
},
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
@@ -975,7 +1009,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 22,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1034,7 +1068,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1044,7 +1078,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 18.20 sec\n",
"Time: 13.21 sec\n",
"\n",
"\n",
"Output text:\n",