mha variants

This commit is contained in:
rasbt
2024-03-06 08:30:32 -06:00
parent d4754f1bdd
commit 87fcfd9245
10 changed files with 431 additions and 10 deletions

View File

@@ -1637,7 +1637,7 @@
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
@@ -1865,7 +1865,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,