Comment typo: head_dim -> head_dim // 2

This commit is contained in:
rasbt
2025-07-23 08:16:30 -05:00
parent b12dbf6c68
commit 4aa398c79d
7 changed files with 8 additions and 8 deletions

View File

@@ -292,7 +292,7 @@ def apply_rope(x, cos, sin, offset=0):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
-cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
+cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation

View File

@@ -236,7 +236,7 @@ def apply_rope(x, cos, sin, offset=0):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
-cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
+cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation

View File

@@ -260,7 +260,7 @@ def apply_rope(x, cos, sin):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
-cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
+cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation

View File

@@ -288,7 +288,7 @@ def apply_rope(x, cos, sin):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
-cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
+cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation