fixed gqa qkv code comments (#660)

This commit is contained in:
Daniel Kleine
2025-06-13 15:21:28 +02:00
committed by GitHub
parent 7632eb018b
commit c2cfb47b1a
3 changed files with 9 additions and 9 deletions

View File

@@ -501,9 +501,9 @@
" ################################################\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
"\n",
" ##################### NEW #####################\n",
" # Apply RoPE\n",

View File

@@ -257,9 +257,9 @@
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
"\n",
" # Apply RoPE\n",
" keys = apply_rope(keys, cos, sin)\n",