mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Improve weight tying handling (#826)
* Improve weight tying handling * fix
This commit is contained in:
committed by
GitHub
parent
1412b139f2
commit
8add26cbe9
@@ -1400,14 +1400,17 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def assign(left, right):\n",
|
||||
"def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" return torch.nn.Parameter(right.clone().detach())\n",
|
||||
" else:\n",
|
||||
" return torch.nn.Parameter(torch.tensor(right))\n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
|
||||
|
||||
@@ -94,10 +94,10 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"blobfile version: 3.0.0\n",
|
||||
"huggingface_hub version: 0.30.1\n",
|
||||
"tiktoken version: 0.9.0\n",
|
||||
"torch version: 2.6.0\n"
|
||||
"blobfile version: 3.1.0\n",
|
||||
"huggingface_hub version: 0.34.4\n",
|
||||
"tiktoken version: 0.11.0\n",
|
||||
"torch version: 2.8.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -1184,16 +1184,7 @@
|
||||
"id": "3357a230-b678-4691-a238-257ee4e80185",
|
||||
"outputId": "a3652def-ea7f-46fb-f293-2a59affb71a0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import login\n",
|
||||
"import json\n",
|
||||
@@ -1226,7 +1217,22 @@
|
||||
"id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
|
||||
"outputId": "c9836ba8-5176-4dd5-b618-6cc36fdbe1f0"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "685326b4fd014ff689e928f4200f5182",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"original/tokenizer.model: 0%| | 0.00/2.18M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from huggingface_hub import hf_hub_download\n",
|
||||
"\n",
|
||||
@@ -1422,7 +1428,64 @@
|
||||
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
|
||||
"outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3af9f77314b14682bbdd1c4921cd193e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7aeb092ad0a14b5e9aaf33bea4751490",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "20adbc86984344a39a55f012b8c18d68",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e6bb24f8ca4344dfb3870fca8c90e4fb",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from safetensors.torch import load_file\n",
|
||||
"\n",
|
||||
@@ -1511,11 +1574,14 @@
|
||||
"def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" return torch.nn.Parameter(right.clone().detach())\n",
|
||||
" else:\n",
|
||||
" return torch.nn.Parameter(torch.tensor(right))\n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_weights_into_llama(model, param_config, params):\n",
|
||||
@@ -1578,7 +1644,7 @@
|
||||
" if \"lm_head.weight\" in params.keys():\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -1733,7 +1799,64 @@
|
||||
"id": "nbvAV7vaz6yc",
|
||||
"outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bdcebc6a21ae41e3bb78834b4f244fae",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "89949427bf5142c29c54978c4f0ae52a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a88b441b15714e138db6fa813dd82a47",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1c4f8df93db246d18494820bb8ec37be",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"combined_weights = {}\n",
|
||||
"\n",
|
||||
@@ -1908,7 +2031,7 @@
|
||||
"1. Grasses: Llamas love to graze on grasses, including tall grasses, short grasses, and even weeds.\n",
|
||||
"2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hays, such as timothy hay, alfalfa hay, and oat hay.\n",
|
||||
"3. Grains: Llamas may be fed grains like oats, corn, and barley as a supplement to their diet.\n",
|
||||
"4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their diet.\n",
|
||||
"4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their meals.\n",
|
||||
"5. Minerals:\n"
|
||||
]
|
||||
}
|
||||
@@ -2060,7 +2183,22 @@
|
||||
"metadata": {
|
||||
"id": "8xDk4chtPNU4"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ac808a4fe89d4ca89597a90f6ab83a30",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"original/tokenizer.model: 0%| | 0.00/2.18M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer_file_path = hf_hub_download(\n",
|
||||
" repo_id=\"meta-llama/Llama-3.1-8B\",\n",
|
||||
@@ -2156,7 +2294,64 @@
|
||||
"id": "u4J7IxOvOyPM",
|
||||
"outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4864b6a5f55340809e1e392cbeb5ca3c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a7c77ab5f83a4319b66856b75cf04e1e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e69661497025474b9523f5035634f788",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f4ca91e917af4a37868e416717c9e762",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"combined_weights = {}\n",
|
||||
"\n",
|
||||
@@ -2341,7 +2536,22 @@
|
||||
"metadata": {
|
||||
"id": "jt8BKAHXRCPI"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7658da8b2a5e4273b45c35411bdba8a0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"original/tokenizer.model: 0%| | 0.00/2.18M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer_file_path = hf_hub_download(\n",
|
||||
" repo_id=\"meta-llama/Llama-3.2-1B\",\n",
|
||||
@@ -2385,9 +2595,48 @@
|
||||
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc004791-9e28-4872-9ae9-fb51c6c83d7c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Alternatively, we can use more robust function that factors in the weight tying based on shared data pointers in memory as suggested in [#822](https://github.com/rasbt/LLMs-from-scratch/issues/822):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"id": "7aaeb28e-62ab-4711-9f07-1b32ac9dbeba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Total number of unique parameters: 1,498,482,688\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def count_unique_parameters(model):\n",
|
||||
" unique_params = set()\n",
|
||||
" total_unique_params = 0\n",
|
||||
" \n",
|
||||
" for param in model.parameters():\n",
|
||||
" if param.data_ptr() not in unique_params:\n",
|
||||
" total_unique_params += param.numel()\n",
|
||||
" unique_params.add(param.data_ptr())\n",
|
||||
" \n",
|
||||
" return total_unique_params\n",
|
||||
"\n",
|
||||
"total_params_uniq = count_unique_parameters(model)\n",
|
||||
"print(f\"\\nTotal number of unique parameters: {total_params_uniq:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"id": "9FbCIYW7RIOe",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -2397,6 +2646,20 @@
|
||||
"outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ebf98f844b6b49669d51601cbceea91e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/2.47G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
@@ -2420,7 +2683,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 49,
|
||||
"id": "pPp5yjir6FYJ",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -2439,9 +2702,29 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Checks that the weight values are the same\n",
|
||||
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"id": "b2bdebe0-d2b0-4d33-8b7e-1b4f9a02ca12",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Weight tying: True\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Furthermore, check if PyTorch uses the same underlying memory\n",
|
||||
"print(\"Weight tying:\", model.tok_emb.weight.data_ptr() == model.out_head.weight.data_ptr())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
|
||||
@@ -80,10 +80,10 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"blobfile version: 3.0.0\n",
|
||||
"huggingface_hub version: 0.30.1\n",
|
||||
"tiktoken version: 0.9.0\n",
|
||||
"torch version: 2.6.0\n"
|
||||
"blobfile version: 3.1.0\n",
|
||||
"huggingface_hub version: 0.34.4\n",
|
||||
"tiktoken version: 0.11.0\n",
|
||||
"torch version: 2.8.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -470,50 +470,9 @@
|
||||
"model = Llama3Model(LLAMA32_CONFIG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
|
||||
"metadata": {
|
||||
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1"
|
||||
},
|
||||
"source": [
|
||||
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total number of parameters: 1,498,482,688\n",
|
||||
"\n",
|
||||
"Total number of unique parameters: 1,235,814,400\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"total_params = sum(p.numel() for p in model.parameters())\n",
|
||||
"print(f\"Total number of parameters: {total_params:,}\")\n",
|
||||
"\n",
|
||||
"# Account for weight tying\n",
|
||||
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
|
||||
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -561,6 +520,31 @@
|
||||
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "41176fb0-d58a-443a-912f-4f436564b5f8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total number of parameters: 1,498,482,688\n",
|
||||
"\n",
|
||||
"Total number of unique parameters: 1,235,814,400\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"total_params = sum(p.numel() for p in model.parameters())\n",
|
||||
"print(f\"Total number of parameters: {total_params:,}\")\n",
|
||||
"\n",
|
||||
"# Account for weight tying\n",
|
||||
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
|
||||
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
@@ -718,16 +702,7 @@
|
||||
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
|
||||
"outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Uncomment and run the following code if you are executing the notebook for the first time\n",
|
||||
"\n",
|
||||
@@ -807,11 +782,14 @@
|
||||
"def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" return torch.nn.Parameter(right.clone().detach())\n",
|
||||
" else:\n",
|
||||
" return torch.nn.Parameter(torch.tensor(right))\n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_weights_into_llama(model, param_config, params):\n",
|
||||
@@ -874,7 +852,7 @@
|
||||
" if \"lm_head.weight\" in params.keys():\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
@@ -945,6 +923,42 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total number of unique parameters: 1,235,814,400\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def count_unique_parameters(model):\n",
|
||||
" unique_params = set()\n",
|
||||
" total_unique_params = 0\n",
|
||||
" \n",
|
||||
" for param in model.parameters():\n",
|
||||
" if param.data_ptr() not in unique_params:\n",
|
||||
" total_unique_params += param.numel()\n",
|
||||
" unique_params.add(param.data_ptr())\n",
|
||||
" \n",
|
||||
" return total_unique_params\n",
|
||||
"\n",
|
||||
"total_params_uniq = count_unique_parameters(model)\n",
|
||||
"print(f\"Total number of unique parameters: {total_params_uniq:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
|
||||
"metadata": {
|
||||
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
|
||||
@@ -959,9 +973,29 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Checks that the weight values are the same\n",
|
||||
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "1ec6977d-ec42-42b5-bca2-3ecda791ea66",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Weight tying: True\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Furthermore, check if PyTorch uses the same underlying memory\n",
|
||||
"print(\"Weight tying:\", model.tok_emb.weight.data_ptr() == model.out_head.weight.data_ptr())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
|
||||
@@ -975,7 +1009,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 22,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1034,7 +1068,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 23,
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||
"metadata": {
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
|
||||
@@ -1044,7 +1078,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Time: 18.20 sec\n",
|
||||
"Time: 13.21 sec\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Output text:\n",
|
||||
|
||||
@@ -768,7 +768,14 @@
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
" \n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
"\n",
|
||||
@@ -881,9 +888,8 @@
|
||||
" if \"lm_head.weight\" in params:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
|
||||
" print(\"Model uses weight tying.\")\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -710,7 +710,14 @@
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
" \n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
"\n",
|
||||
@@ -823,9 +830,8 @@
|
||||
" if \"lm_head.weight\" in params:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
|
||||
" print(\"Model uses weight tying.\")\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -792,7 +792,14 @@
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
" \n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
"\n",
|
||||
@@ -873,9 +880,8 @@
|
||||
" if \"lm_head.weight\" in params:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
|
||||
" print(\"Model uses weight tying.\")\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -734,7 +734,14 @@
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
" \n",
|
||||
" return left \n",
|
||||
"\n",
|
||||
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||
"\n",
|
||||
@@ -815,9 +822,8 @@
|
||||
" if \"lm_head.weight\" in params:\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||
" else:\n",
|
||||
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
|
||||
" print(\"Model uses weight tying.\")\n",
|
||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -427,7 +427,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
||||
"metadata": {
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
||||
@@ -842,10 +842,15 @@
|
||||
"\n",
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(\n",
|
||||
" f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
|
||||
" )\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" return left\n",
|
||||
"\n",
|
||||
" # Embedding weights\n",
|
||||
" if \"model.embed_tokens.weight\" in params:\n",
|
||||
@@ -948,13 +953,9 @@
|
||||
" params[\"lm_head.weight\"],\n",
|
||||
" \"lm_head.weight\",\n",
|
||||
" )\n",
|
||||
" elif \"model.embed_tokens.weight\" in params:\n",
|
||||
" # Weight tying: reuse the embedding weights\n",
|
||||
" model.out_head.weight = assign(\n",
|
||||
" model.out_head.weight,\n",
|
||||
" params[\"model.embed_tokens.weight\"],\n",
|
||||
" \"model.embed_tokens.weight\",\n",
|
||||
" )"
|
||||
" else:\n",
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1036,7 +1037,15 @@
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model uses weight tying.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
@@ -1175,7 +1184,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 24,
|
||||
"id": "988f55e2-0f60-4bd8-ae55-db116ff2b26d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1186,7 +1195,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 25,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1213,7 +1222,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 26,
|
||||
"id": "56c9d0cf-25e9-4375-8d5c-368fa6911fdf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
||||
@@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 7,
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
||||
"metadata": {
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
||||
@@ -320,7 +320,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 8,
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
||||
"metadata": {
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
||||
@@ -386,7 +386,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
||||
"metadata": {
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
||||
@@ -507,7 +507,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 10,
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
||||
"metadata": {
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
||||
@@ -554,7 +554,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 11,
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
||||
"metadata": {
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
||||
@@ -567,7 +567,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 12,
|
||||
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -602,7 +602,7 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -621,7 +621,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -634,7 +634,7 @@
|
||||
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -645,7 +645,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 14,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -676,7 +676,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 15,
|
||||
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -726,7 +726,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 16,
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
|
||||
"metadata": {
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
|
||||
@@ -756,7 +756,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 17,
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e",
|
||||
"metadata": {
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e"
|
||||
@@ -767,10 +767,15 @@
|
||||
"\n",
|
||||
" def assign(left, right, tensor_name=\"unknown\"):\n",
|
||||
" if left.shape != right.shape:\n",
|
||||
" raise ValueError(\n",
|
||||
" f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
|
||||
" )\n",
|
||||
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
||||
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" if isinstance(right, torch.Tensor):\n",
|
||||
" left.copy_(right)\n",
|
||||
" else:\n",
|
||||
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
|
||||
"\n",
|
||||
" return left\n",
|
||||
"\n",
|
||||
" # Embedding weights\n",
|
||||
" if \"model.embed_tokens.weight\" in params:\n",
|
||||
@@ -873,13 +878,9 @@
|
||||
" params[\"lm_head.weight\"],\n",
|
||||
" \"lm_head.weight\",\n",
|
||||
" )\n",
|
||||
" elif \"model.embed_tokens.weight\" in params:\n",
|
||||
" # Weight tying: reuse the embedding weights\n",
|
||||
" model.out_head.weight = assign(\n",
|
||||
" model.out_head.weight,\n",
|
||||
" params[\"model.embed_tokens.weight\"],\n",
|
||||
" \"model.embed_tokens.weight\",\n",
|
||||
" )"
|
||||
" else:\n",
|
||||
" model.out_head.weight = model.tok_emb.weight\n",
|
||||
" print(\"Model uses weight tying.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -900,7 +901,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 18,
|
||||
"id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -913,7 +914,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 19,
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -936,7 +937,15 @@
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model uses weight tying.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
@@ -989,7 +998,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 20,
|
||||
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1019,7 +1028,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 21,
|
||||
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1037,7 +1046,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 22,
|
||||
"id": "1946b534-e3af-431a-a222-391a60bfa892",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1047,7 +1056,7 @@
|
||||
"'<bos><start_of_turn>user\\nGive me a short introduction to large language models.<end_of_turn>\\n<start_of_turn>model\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -1075,7 +1084,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 23,
|
||||
"id": "a6250333-9cf0-4f36-8e28-76be2eac1c43",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1086,7 +1095,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 24,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1112,7 +1121,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 25,
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||
"metadata": {
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
|
||||
@@ -1176,7 +1185,7 @@
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -1190,7 +1199,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.6"
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -503,13 +503,17 @@ def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
|
||||
if isinstance(right, torch.Tensor):
|
||||
return torch.nn.Parameter(right.clone().detach())
|
||||
else:
|
||||
return torch.nn.Parameter(torch.tensor(right))
|
||||
with torch.no_grad():
|
||||
if isinstance(right, torch.Tensor):
|
||||
left.copy_(right)
|
||||
else:
|
||||
left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))
|
||||
|
||||
return left
|
||||
|
||||
|
||||
def load_weights_into_llama(model, param_config, params):
|
||||
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
for l in range(param_config["n_layers"]):
|
||||
@@ -569,5 +573,5 @@ def load_weights_into_llama(model, param_config, params):
|
||||
if "lm_head.weight" in params.keys():
|
||||
model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
|
||||
else:
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
model.out_head.weight = model.tok_emb.weight
|
||||
print("Model uses weight tying.")
|
||||
|
||||
@@ -64,7 +64,7 @@ QWEN3_CONFIG_8B = {
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 4096, # 60% larger than above
|
||||
"n_heads": 32,
|
||||
"n_layers": 36,
|
||||
"n_layers": 36,
|
||||
"hidden_dim": 12288, # 26% larger than above
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
@@ -387,7 +387,14 @@ def load_weights_into_qwen(model, param_config, params):
|
||||
def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(right, torch.Tensor):
|
||||
left.copy_(right)
|
||||
else:
|
||||
left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))
|
||||
|
||||
return left
|
||||
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
@@ -500,9 +507,8 @@ def load_weights_into_qwen(model, param_config, params):
|
||||
if "lm_head.weight" in params:
|
||||
model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
|
||||
else:
|
||||
# Model uses weight tying, hence we reuse the embedding layer weights here
|
||||
model.out_head.weight = model.tok_emb.weight
|
||||
print("Model uses weight tying.")
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
|
||||
class Qwen3Tokenizer:
|
||||
|
||||
Reference in New Issue
Block a user