diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index 11615cb..f935060 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -10,26 +10,16 @@ "# Efficient Multi-Head Attention Implementations" ] }, - { - "cell_type": "markdown", - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", - "metadata": { - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" - }, - "source": [ - "## Multi-head attention implementations from chapter 3" - ] - }, { "cell_type": "code", "execution_count": 1, "id": "7898551e-f582-48ac-9f66-3632abe2a93f", "metadata": { - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "840126fe-fffa-46d4-9717-41aef89d5052" + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "outputId": "02205088-47f1-4fc1-83a4-dd0be4cd64dd" }, "outputs": [ { @@ -53,16 +43,26 @@ "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" ] }, + { + "cell_type": "markdown", + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", + "metadata": { + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" + }, + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3" + ] + }, { "cell_type": "code", "execution_count": 2, "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", "metadata": { - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "outputId": "5af9d36b-37c9-4f6e-c370-58a46db02632", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "outputId": "a1eefc3c-21ea-463e-e75e-06af9f6262dd" }, "outputs": [ { @@ -89,16 +89,26 @@ "print(out.shape)" ] }, + { + "cell_type": "markdown", + "id": "21930804-b327-40b1-8e63-94dcad39ce7b", + "metadata": { + "id": "21930804-b327-40b1-8e63-94dcad39ce7b" + }, + "source": [ + "## 2) The multi-head attention class from chapter 3" + ] + }, { "cell_type": "code", "execution_count": 3, "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", "metadata": { - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "outputId": "1c7ffc71-3b51-4ee8-beab-261625b1473e", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "outputId": "c66ee5fd-b0cd-4ab4-e097-4d64902ea0d0" }, "outputs": [ { @@ -132,7 +142,7 @@ "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" }, "source": [ - "## An alternative multi-head attention with combined weights" + "## 3) An alternative multi-head attention with combined weights" ] }, { @@ -158,11 +168,11 @@ "execution_count": 4, "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", "metadata": { - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "outputId": "3c225fe5-73a9-4df0-c513-6296f4bb5261", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "outputId": "9c4ffbe8-6684-429c-b86a-b68121341a4c" }, "outputs": [ { @@ -253,7 +263,7 @@ "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" }, "source": [ - "## Multihead attention with PyTorch's scaled dot product attention" + "## 4) Multihead attention with PyTorch's scaled dot product attention" ] }, { @@ -275,7 +285,7 @@ }, "outputs": [], "source": [ - "class MultiHeadAttentionPyTorch(nn.Module):\n", + "class MHAPyTorchScaledDotProduct(nn.Module):\n", " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", " super().__init__()\n", "\n", @@ -324,11 +334,11 @@ "execution_count": 6, "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", "metadata": { - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "outputId": "f3e7933d-16d3-45e5-f03d-610319004579", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "outputId": "027b5a66-4e17-49e8-9e80-9c70eaf201ab" }, "outputs": [ { @@ -340,7 +350,7 @@ } ], "source": [ - "mha_pytorch = MultiHeadAttentionPyTorch(\n", + "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " block_size=context_len,\n", @@ -349,7 +359,96 @@ " qkv_bias=False\n", ").to(device)\n", "\n", - "out = mha_pytorch(embeddings)\n", + "out = mha_pytorch_scaled(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "351c318f-4835-4d74-8d58-a070222447c4", + "metadata": { + "id": "351c318f-4835-4d74-8d58-a070222447c4" + }, + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention" + ] + }, + { + "cell_type": "markdown", + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", + "metadata": { + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" + }, + "source": [ + "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "outputId": "9d9afbbd-2e85-44cb-afc9-8cb3c91e8368" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "class MHAPyTorchClass(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " self.block_size = block_size\n", + " self.multihead_attn = nn.MultiheadAttention(\n", + " embed_dim=d_out,\n", + " num_heads=num_heads,\n", + " dropout=dropout,\n", + " bias=qkv_bias,\n", + " add_bias_kv=qkv_bias,\n", + " batch_first=True\n", + " )\n", + "\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, _ = x.shape\n", + "\n", + " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", + " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", + " if self.block_size >= num_tokens:\n", + " attn_mask = self.mask[:num_tokens, :num_tokens]\n", + " else:\n", + " attn_mask = self.mask[:self.block_size, :self.block_size]\n", + "\n", + " # attn_mask broadcasting will handle batch_size dimension implicitly\n", + " attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask)\n", + "\n", + " output = self.proj(attn_output)\n", + "\n", + " return output\n", + "\n", + "\n", + "mha_pytorch_class = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class(embeddings)\n", "print(out.shape)" ] }, @@ -365,104 +464,140 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", "metadata": { - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "outputId": "bb928da8-6ac0-4d15-cf12-4903d73708fc", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "41.1 ms ± 9.08 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "41.1 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", "%timeit mha_ch03_wrapper(embeddings)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", "metadata": { - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "outputId": "54f8e05e-0cb2-4e4a-cacd-27a309a3be8b", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "6.58 ms ± 582 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "6.58 ms ± 143 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ + "## 2) The multi-head attention class from chapter 3\n", "%timeit mha_ch03(embeddings)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", "metadata": { - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "outputId": "415e959e-b648-4f1e-f05e-8b8444e74bee", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "7.2 ms ± 327 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "7.19 ms ± 294 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ + "## 3) An alternative multi-head attention with combined weights\n", "%timeit mha_combined_qkv(embeddings)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", "metadata": { - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "outputId": "05b7c696-1b97-4f18-8430-481bb8940b6b", "colab": { "base_uri": "https://localhost:8080/" - } + }, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "outputId": "80c6e314-0771-470e-b090-628984ce2d85" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "2.38 ms ± 386 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" + "2.37 ms ± 432 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" ] } ], "source": [ - "%timeit mha_pytorch(embeddings)" + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "6.66 ms ± 397 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class(embeddings)" ] } ], "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { @@ -475,14 +610,8 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "colab": { - "provenance": [], - "machine_shape": "hm", - "gpuType": "A100" - }, - "accelerator": "GPU" + "version": "3.10.6" + } }, "nbformat": 4, "nbformat_minor": 5