Improve MHA einsum (#775)

This commit is contained in:
Sebastian Raschka
2025-08-19 10:38:15 -05:00
committed by GitHub
parent 80d4732456
commit 8c1f9ccf54

View File

@@ -58,7 +58,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.6.0+cu124\n"
"PyTorch version: 2.8.0\n"
]
}
],
@@ -89,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "1db27f43-86f4-478f-89df-fbc2182a129b",
"metadata": {
"id": "1db27f43-86f4-478f-89df-fbc2182a129b"
@@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"colab": {
@@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"colab": {
@@ -326,7 +326,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"colab": {
@@ -434,7 +434,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
"metadata": {
"colab": {
@@ -466,7 +466,6 @@
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" # Initialize parameters for Q, K, V\n",
" self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
" self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
" self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
@@ -483,8 +482,6 @@
" self.out_proj = nn.Linear(d_out, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" # Initialize parameters\n",
" self.reset_parameters()\n",
"\n",
"\n",