diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 7766dca..8ee68bc 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -452,10 +452,8 @@ "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", - " self, d_in, d_out, context_length, num_heads,\n", + " self, d_in, d_out, num_heads,\n", " num_kv_groups, # NEW\n", - " rope_base=10_000, # NEW\n", - " rope_config=None, # NEW\n", " dtype=None\n", " ):\n", " super().__init__()\n", @@ -645,10 +643,8 @@ "gqa = GroupedQueryAttention(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", - " context_length=max_context_len,\n", " num_heads=num_heads,\n", " num_kv_groups=8,\n", - " rope_base=llama_3_theta_base\n", ")\n", "\n", "gqa(example_batch)\n", @@ -753,11 +749,8 @@ " self.att = GroupedQueryAttention( # MultiHeadAttention(\n", " d_in=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n", - " context_length=cfg[\"context_length\"],\n", " num_heads=cfg[\"n_heads\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"], # NEW\n", - " rope_base=cfg[\"rope_base\"], # NEW\n", - " rope_config=cfg[\"rope_freq\"], # NEW\n", " dtype=cfg[\"dtype\"]\n", " )\n", " self.ff = FeedForward(cfg)\n",