diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index 0e8bc04..a1fc0ea 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -435,7 +435,7 @@ " positions = torch.arange(context_length)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index bdd1065..23a060d 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -310,7 +310,7 @@ " positions = torch.arange(context_length)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index 7fc363f..264e23c 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -180,7 +180,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 3ce53ad..36f8f9d 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -275,7 +275,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/11_qwen3/standalone-qwen3-moe.ipynb b/ch05/11_qwen3/standalone-qwen3-moe.ipynb index a979538..5c1a402 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe.ipynb @@ -275,7 +275,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb index 2753be4..ca9e15e 100644 --- a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb @@ -206,7 +206,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb index 6bd38d3..0302990 100644 --- a/ch05/11_qwen3/standalone-qwen3.ipynb +++ b/ch05/11_qwen3/standalone-qwen3.ipynb @@ -204,7 +204,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb index d3b6af9..a496dc0 100644 --- a/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb +++ b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb @@ -200,7 +200,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/ch05/12_gemma3/standalone-gemma3.ipynb b/ch05/12_gemma3/standalone-gemma3.ipynb index e516468..5c45e20 100644 --- a/ch05/12_gemma3/standalone-gemma3.ipynb +++ b/ch05/12_gemma3/standalone-gemma3.ipynb @@ -200,7 +200,7 @@ " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py index 0bfb469..71d2469 100644 --- a/pkg/llms_from_scratch/llama3.py +++ b/pkg/llms_from_scratch/llama3.py @@ -238,7 +238,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c positions = torch.arange(context_length, dtype=dtype) # Compute the angles - angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2) + angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2) # Expand angles to match the head_dim angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim) diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index a68b324..8ae2a02 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -326,7 +326,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype= positions = torch.arange(context_length, dtype=dtype) # Compute the angles - angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2) + angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2) # Expand angles to match the head_dim angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)