Interleaved Q and K for RoPE in Llama 2 (#750)

This commit is contained in:
Sebastian Raschka
2025-07-23 08:02:02 -05:00
committed by GitHub
parent b74ab9611e
commit 19c065b342

View File

@@ -83,7 +83,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.33.0\n",
"huggingface_hub version: 0.33.2\n",
"sentencepiece version: 0.2.0\n",
"torch version: 2.6.0\n"
]
@@ -1306,22 +1306,7 @@
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
"outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66e777955e8748df878f118f07f38dab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"weights_file = hf_hub_download(\n",
" repo_id=\"meta-llama/Llama-2-7b\",\n",
@@ -1405,7 +1390,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 32,
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
"metadata": {
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
@@ -1422,19 +1407,40 @@
" return torch.nn.Parameter(torch.tensor(right))\n",
"\n",
"\n",
"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
" return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
" .transpose(1, 2) # put axis 2 next to heads\n",
" .reshape(out_dim, in_dim))\n",
"\n",
"\n",
"def load_weights_into_llama(model, param_config, params):\n",
"\n",
" cfg = LLAMA2_CONFIG_7B\n",
" \n",
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
"\n",
" # Load attention weights\n",
"        # The original Meta/Llama checkpoints store Q and K so that the two numbers \n",
"        # that form one complex RoPE pair sit next to each other inside the head dimension (\"sliced\" layout).\n",
"        # Our RoPE implementation, similar to the one in Hugging Face, expects an interleaved layout.\n",
" # For example, with n_heads=2 and head_dim = 8\n",
" # ┌── pair 0 ──┐ ┌── pair 1 ──┐\n",
" # Meta (sliced): [ h0: r0 r1 r2 r3, h1: r0 r1 r2 r3 ]\n",
" # Ours & HF (interleaved): [ h0: r0 r0 r1 r1 r2 r2 r3 r3 , h1: ... ]\n",
" # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747 \n",
" \n",
"        # So, below, for q_raw and k_raw, we must reorder the checkpoint weights using the permute helper\n",
"\n",
" q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
" model.trf_blocks[l].att.W_query.weight = assign(\n",
" model.trf_blocks[l].att.W_query.weight,\n",
" params[f\"layers.{l}.attention.wq.weight\"]\n",
" permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
" )\n",
" k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
" model.trf_blocks[l].att.W_key.weight = assign(\n",
" model.trf_blocks[l].att.W_key.weight,\n",
" params[f\"layers.{l}.attention.wk.weight\"]\n",
" permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
" )\n",
" model.trf_blocks[l].att.W_value.weight = assign(\n",
" model.trf_blocks[l].att.W_value.weight,\n",
@@ -1489,7 +1495,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 33,
"id": "240987e8-a023-462e-9376-9edfb27559ec",
"metadata": {
"colab": {
@@ -1504,7 +1510,7 @@
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
" Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
]
}
],
@@ -1544,7 +1550,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 35,
"id": "nbvAV7vaz6yc",
"metadata": {
"colab": {
@@ -1568,27 +1574,14 @@
"outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" What do llamas eat?\n",
"Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
"\n",
"Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
]
}
],
@@ -1601,6 +1594,7 @@
" local_dir=\"Llama-2-7b-chat\"\n",
")\n",
"\n",
"weights = torch.load(weights_file, weights_only=True)\n",
"model = Llama2Model(LLAMA2_CONFIG_7B)\n",
"load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
"model.to(device);\n",