remove redundant context_length in GQA

This commit is contained in:
rasbt
2025-03-31 16:49:10 -05:00
parent 06ebac3c34
commit 4715dc3be5

View File

@@ -233,7 +233,7 @@
"source": [ "source": [
"class GroupedQueryAttention(nn.Module):\n", "class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n", " def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n", " self, d_in, d_out, num_heads,\n",
" num_kv_groups,\n", " num_kv_groups,\n",
" dtype=None\n", " dtype=None\n",
" ):\n", " ):\n",
@@ -320,7 +320,6 @@
" self.att = GroupedQueryAttention(\n", " self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n", " d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"],\n", " num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n",
" dtype=cfg[\"dtype\"]\n", " dtype=cfg[\"dtype\"]\n",