mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Interleaved Q and K for RoPE in Llama 2 (#750)
This commit is contained in:
committed by
GitHub
parent
b74ab9611e
commit
19c065b342
@@ -83,7 +83,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"huggingface_hub version: 0.33.0\n",
|
||||
"huggingface_hub version: 0.33.2\n",
|
||||
"sentencepiece version: 0.2.0\n",
|
||||
"torch version: 2.6.0\n"
|
||||
]
|
||||
@@ -1306,22 +1306,7 @@
|
||||
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
|
||||
"outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "66e777955e8748df878f118f07f38dab",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"weights_file = hf_hub_download(\n",
|
||||
" repo_id=\"meta-llama/Llama-2-7b\",\n",
|
||||
@@ -1405,7 +1390,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 32,
|
||||
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
|
||||
"metadata": {
|
||||
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
|
||||
@@ -1422,19 +1407,40 @@
|
||||
" return torch.nn.Parameter(torch.tensor(right))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
|
||||
" return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
|
||||
" .transpose(1, 2) # put axis 2 next to heads\n",
|
||||
" .reshape(out_dim, in_dim))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_weights_into_llama(model, param_config, params):\n",
|
||||
"\n",
|
||||
" cfg = LLAMA2_CONFIG_7B\n",
|
||||
" \n",
|
||||
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
|
||||
"\n",
|
||||
" for l in range(param_config[\"n_layers\"]):\n",
|
||||
"\n",
|
||||
" # Load attention weights\n",
|
||||
" # The original Meta/Llama checkpoints store Q and K so that the two numbers \n",
|
||||
" # that form one complex RoPE pair sit next to each other inside the head dimension (\"sliced\" layout).\n",
|
||||
" # Our RoPE implementation, similar to the one in Hugging Face, expects an interleaved layout\n",
|
||||
" # For example, with n_heads=2 and head_dim = 8\n",
|
||||
" # ┌── pair 0 ──┐ ┌── pair 1 ──┐\n",
|
||||
" # Meta (sliced): [ h0: r0 r1 r2 r3, h1: r0 r1 r2 r3 ]\n",
|
||||
" # Ours & HF (interleaved): [ h0: r0 r0 r1 r1 r2 r2 r3 r3 , h1: ... ]\n",
|
||||
" # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747 \n",
|
||||
" \n",
|
||||
" # So, below, for q_raw and k_raw, we must re‑order the checkpoint weights using the slices_to_interleave helper\n",
|
||||
"\n",
|
||||
" q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
|
||||
" model.trf_blocks[l].att.W_query.weight = assign(\n",
|
||||
" model.trf_blocks[l].att.W_query.weight,\n",
|
||||
" params[f\"layers.{l}.attention.wq.weight\"]\n",
|
||||
" permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
|
||||
" )\n",
|
||||
" k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
|
||||
" model.trf_blocks[l].att.W_key.weight = assign(\n",
|
||||
" model.trf_blocks[l].att.W_key.weight,\n",
|
||||
" params[f\"layers.{l}.attention.wk.weight\"]\n",
|
||||
" permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
|
||||
" )\n",
|
||||
" model.trf_blocks[l].att.W_value.weight = assign(\n",
|
||||
" model.trf_blocks[l].att.W_value.weight,\n",
|
||||
@@ -1489,7 +1495,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 33,
|
||||
"id": "240987e8-a023-462e-9376-9edfb27559ec",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -1504,7 +1510,7 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output text:\n",
|
||||
" Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
|
||||
" Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -1544,7 +1550,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 35,
|
||||
"id": "nbvAV7vaz6yc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -1568,27 +1574,14 @@
|
||||
"outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output text:\n",
|
||||
" What do llamas eat?\n",
|
||||
"Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
|
||||
"\n",
|
||||
"Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -1601,6 +1594,7 @@
|
||||
" local_dir=\"Llama-2-7b-chat\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"weights = torch.load(weights_file, weights_only=True)\n",
|
||||
"model = Llama2Model(LLAMA2_CONFIG_7B)\n",
|
||||
"load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
|
||||
"model.to(device);\n",
|
||||
|
||||
Reference in New Issue
Block a user