mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
mha variants
This commit is contained in:
@@ -1637,7 +1637,7 @@
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
@@ -1865,7 +1865,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -243,7 +243,7 @@
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
@@ -342,7 +342,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user