mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Note about RoPE usage (#839)
* Note about devcontainer root usage * Add note about RoPE implementation
This commit is contained in:
committed by
GitHub
parent
42c130623b
commit
2aa8e8130d
@@ -205,6 +205,58 @@ class GroupedQueryAttention(nn.Module):
|
||||
return context_vec
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# RoPE implementation summary
|
||||
#
|
||||
#
|
||||
# There are two common styles to implement RoPE, which are
|
||||
# mathematically equivalent;
|
||||
# they mainly differ in how the rotation matrix pairs dimensions.
|
||||
#
|
||||
# 1) Split-halves style (this repo, Hugging Face Transformers):
|
||||
#
|
||||
# For hidden dim d = 8 (example):
|
||||
#
|
||||
# [ x0 x1 x2 x3 x4 x5 x6 x7 ]
|
||||
# │ │ │ │ │ │ │ │
|
||||
# ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
|
||||
# cos cos cos cos sin sin sin sin
|
||||
#
|
||||
# Rotation matrix:
|
||||
#
|
||||
# [ cosθ -sinθ 0 0 ... ]
|
||||
# [ sinθ cosθ 0 0 ... ]
|
||||
# [ 0 0 cosθ -sinθ ... ]
|
||||
# [ 0 0 sinθ cosθ ... ]
|
||||
# ...
|
||||
#
|
||||
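# Here, the embedding dims are split into two halves, and dim i of the first
|
||||
# half is paired with dim i + d/2 of the second half (e.g., x0 with x4); each such pair is rotated by its own 2x2 block.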
|
||||
#
|
||||
#
|
||||
# 2) Interleaved (even/odd) style (original paper, Llama repo):
|
||||
#
|
||||
# For hidden dim d = 8 (example):
|
||||
#
|
||||
# [ x0 x1 x2 x3 x4 x5 x6 x7 ]
|
||||
# │ │ │ │ │ │ │ │
|
||||
# ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
|
||||
# cos sin cos sin cos sin cos sin
|
||||
#
|
||||
# Rotation matrix:
|
||||
# [ cosθ -sinθ 0 0 ... ]
|
||||
# [ sinθ cosθ 0 0 ... ]
|
||||
# [ 0 0 cosθ -sinθ ... ]
|
||||
# [ 0 0 sinθ cosθ ... ]
|
||||
# ...
|
||||
#
|
||||
# Here, adjacent even/odd dims are paired, i.e., (x0, x1), (x2, x3), ..., and each pair is rotated by its own 2x2 block.
|
||||
#
|
||||
# Both layouts encode the same relative positions; the only difference is how
|
||||
# dimensions are paired.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
|
||||
@@ -316,6 +316,58 @@ class GroupedQueryAttention(nn.Module):
|
||||
return self.out_proj(context)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# RoPE implementation summary
|
||||
#
|
||||
#
|
||||
# There are two common styles to implement RoPE, which are
|
||||
# mathematically equivalent;
|
||||
# they mainly differ in how the rotation matrix pairs dimensions.
|
||||
#
|
||||
# 1) Split-halves style (this repo, Hugging Face Transformers):
|
||||
#
|
||||
# For hidden dim d = 8 (example):
|
||||
#
|
||||
# [ x0 x1 x2 x3 x4 x5 x6 x7 ]
|
||||
# │ │ │ │ │ │ │ │
|
||||
# ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
|
||||
# cos cos cos cos sin sin sin sin
|
||||
#
|
||||
# Rotation matrix:
|
||||
#
|
||||
# [ cosθ -sinθ 0 0 ... ]
|
||||
# [ sinθ cosθ 0 0 ... ]
|
||||
# [ 0 0 cosθ -sinθ ... ]
|
||||
# [ 0 0 sinθ cosθ ... ]
|
||||
# ...
|
||||
#
|
||||
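# Here, the embedding dims are split into two halves, and dim i of the first
|
||||
# half is paired with dim i + d/2 of the second half (e.g., x0 with x4); each such pair is rotated by its own 2x2 block.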
|
||||
#
|
||||
#
|
||||
# 2) Interleaved (even/odd) style (original paper, Llama repo):
|
||||
#
|
||||
# For hidden dim d = 8 (example):
|
||||
#
|
||||
# [ x0 x1 x2 x3 x4 x5 x6 x7 ]
|
||||
# │ │ │ │ │ │ │ │
|
||||
# ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
|
||||
# cos sin cos sin cos sin cos sin
|
||||
#
|
||||
# Rotation matrix:
|
||||
# [ cosθ -sinθ 0 0 ... ]
|
||||
# [ sinθ cosθ 0 0 ... ]
|
||||
# [ 0 0 cosθ -sinθ ... ]
|
||||
# [ 0 0 sinθ cosθ ... ]
|
||||
# ...
|
||||
#
|
||||
# Here, adjacent even/odd dims are paired, i.e., (x0, x1), (x2, x3), ..., and each pair is rotated by its own 2x2 block.
|
||||
#
|
||||
# Both layouts encode the same relative positions; the only difference is how
|
||||
# dimensions are paired.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user