mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Improve MHA einsum (#775)
This commit is contained in:
committed by
GitHub
parent
80d4732456
commit
8c1f9ccf54
@@ -58,7 +58,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"PyTorch version: 2.6.0+cu124\n"
|
||||
"PyTorch version: 2.8.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -89,7 +89,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "1db27f43-86f4-478f-89df-fbc2182a129b",
|
||||
"metadata": {
|
||||
"id": "1db27f43-86f4-478f-89df-fbc2182a129b"
|
||||
@@ -114,7 +114,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -205,7 +205,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -326,7 +326,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -434,7 +434,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -466,7 +466,6 @@
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.head_dim = d_out // num_heads\n",
|
||||
"\n",
|
||||
" # Initialize parameters for Q, K, V\n",
|
||||
" self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
|
||||
" self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
|
||||
" self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
|
||||
@@ -483,8 +482,6 @@
|
||||
" self.out_proj = nn.Linear(d_out, d_out)\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
||||
"\n",
|
||||
" # Initialize parameters\n",
|
||||
" self.reset_parameters()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
Reference in New Issue
Block a user