mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
fixed gqa qkv code comments (#660)
This commit is contained in:
@@ -501,9 +501,9 @@
|
||||
" ################################################\n",
|
||||
"\n",
|
||||
" # Transpose keys, values, and queries\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # Apply RoPE\n",
|
||||
|
||||
@@ -257,9 +257,9 @@
|
||||
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
|
||||
"\n",
|
||||
" # Transpose keys, values, and queries\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" # Apply RoPE\n",
|
||||
" keys = apply_rope(keys, cos, sin)\n",
|
||||
|
||||
Reference in New Issue
Block a user