fixed gqa qkv code comments (#660)

This commit is contained in:
Daniel Kleine
2025-06-13 15:21:28 +02:00
committed by GitHub
parent 7632eb018b
commit c2cfb47b1a
3 changed files with 9 additions and 9 deletions

View File

@@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# Transpose keys, values, and queries
keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)
keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
# Apply RoPE
keys = apply_rope(keys, cos, sin)