diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 908a034..7766dca 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -501,9 +501,9 @@ " ################################################\n", "\n", " # Transpose keys, values, and queries\n", - " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", + " keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", + " values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", + " queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", "\n", " ##################### NEW #####################\n", " # Apply RoPE\n", diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index dbec8ad..afb27c2 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -257,9 +257,9 @@ " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", "\n", " # Transpose keys, values, and queries\n", - " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", + " keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", + " values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", + " queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", "\n", " # Apply RoPE\n", " keys = apply_rope(keys, cos, sin)\n", diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py index df7bc72..785e8af 100644 --- a/pkg/llms_from_scratch/llama3.py +++ b/pkg/llms_from_scratch/llama3.py @@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module): values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) # Transpose keys, values, and queries - keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim) - values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim) - queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim) + keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim) + values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim) + queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim) # Apply RoPE keys = apply_rope(keys, cos, sin)