Add and link bonus material (#84)

This commit is contained in:
Sebastian Raschka
2024-03-23 07:27:43 -05:00
committed by GitHub
parent 35c6e12730
commit cf39abac04
12 changed files with 110 additions and 13 deletions

View File

@@ -18,7 +18,7 @@
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
},
"source": [
"# Efficient Multi-Head Attention Implementations"
"# Comparing Efficient Multi-Head Attention Implementations"
]
},
{
@@ -73,6 +73,9 @@
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
@@ -119,6 +122,9 @@
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 2) The multi-head attention class from chapter 3"
]
},
@@ -165,6 +171,9 @@
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 3) An alternative multi-head attention with combined weights"
]
},
@@ -286,6 +295,9 @@
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 4) Multi-head attention with PyTorch's scaled dot product attention"
]
},
@@ -393,6 +405,9 @@
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 5) Using PyTorch's torch.nn.MultiheadAttention"
]
},
@@ -488,6 +503,9 @@
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
]
},
@@ -548,6 +566,9 @@
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (M3 MacBook Air CPU)"
]
},
@@ -706,6 +727,9 @@
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (Nvidia A100 GPU)"
]
},
@@ -866,6 +890,10 @@
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup"
]
},
@@ -1003,7 +1031,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,