mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
More efficient angles computation in RoPE (#830)
This commit is contained in:
committed by
GitHub
parent
147dc49ab5
commit
b6cd0a312f
@@ -200,7 +200,7 @@
|
||||
" positions = torch.arange(context_length, dtype=dtype)\n",
|
||||
"\n",
|
||||
" # Compute the angles\n",
|
||||
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
|
||||
" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
|
||||
"\n",
|
||||
" # Expand angles to match the head_dim\n",
|
||||
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
|
||||
|
||||
@@ -200,7 +200,7 @@
|
||||
" positions = torch.arange(context_length, dtype=dtype)\n",
|
||||
"\n",
|
||||
" # Compute the angles\n",
|
||||
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
|
||||
" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
|
||||
"\n",
|
||||
" # Expand angles to match the head_dim\n",
|
||||
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
|
||||
|
||||
Reference in New Issue
Block a user