add pytorch mha

This commit is contained in:
rasbt
2024-03-08 09:30:55 -06:00
parent 3beaea46ce
commit 5643c88db9

View File

@@ -10,26 +10,16 @@
"# Efficient Multi-Head Attention Implementations"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"## Multi-head attention implementations from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "840126fe-fffa-46d4-9717-41aef89d5052"
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"outputId": "02205088-47f1-4fc1-83a4-dd0be4cd64dd"
},
"outputs": [
{
@@ -53,16 +43,26 @@
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "5af9d36b-37c9-4f6e-c370-58a46db02632",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "a1eefc3c-21ea-463e-e75e-06af9f6262dd"
},
"outputs": [
{
@@ -89,16 +89,26 @@
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
"metadata": {
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"## 2) The multi-head attention class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "1c7ffc71-3b51-4ee8-beab-261625b1473e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "c66ee5fd-b0cd-4ab4-e097-4d64902ea0d0"
},
"outputs": [
{
@@ -132,7 +142,7 @@
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"## An alternative multi-head attention with combined weights"
"## 3) An alternative multi-head attention with combined weights"
]
},
{
@@ -158,11 +168,11 @@
"execution_count": 4,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "3c225fe5-73a9-4df0-c513-6296f4bb5261",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "9c4ffbe8-6684-429c-b86a-b68121341a4c"
},
"outputs": [
{
@@ -253,7 +263,7 @@
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"## Multihead attention with PyTorch's scaled dot product attention"
"## 4) Multihead attention with PyTorch's scaled dot product attention"
]
},
{
@@ -275,7 +285,7 @@
},
"outputs": [],
"source": [
"class MultiHeadAttentionPyTorch(nn.Module):\n",
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
@@ -324,11 +334,11 @@
"execution_count": 6,
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"metadata": {
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "f3e7933d-16d3-45e5-f03d-610319004579",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "027b5a66-4e17-49e8-9e80-9c70eaf201ab"
},
"outputs": [
{
@@ -340,7 +350,7 @@
}
],
"source": [
"mha_pytorch = MultiHeadAttentionPyTorch(\n",
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
@@ -349,7 +359,96 @@
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch(embeddings)\n",
"out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "351c318f-4835-4d74-8d58-a070222447c4",
"metadata": {
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"## 5) Using PyTorch's torch.nn.MultiheadAttention"
]
},
{
"cell_type": "markdown",
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
"metadata": {
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
},
"source": [
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"outputId": "9d9afbbd-2e85-44cb-afc9-8cb3c91e8368"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" self.block_size = block_size\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
" dropout=dropout,\n",
" bias=qkv_bias,\n",
" add_bias_kv=qkv_bias,\n",
" batch_first=True\n",
" )\n",
"\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.block_size >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.block_size, :self.block_size]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask)\n",
"\n",
" output = self.proj(attn_output)\n",
"\n",
" return output\n",
"\n",
"\n",
"mha_pytorch_class = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class(embeddings)\n",
"print(out.shape)"
]
},
@@ -365,104 +464,140 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"metadata": {
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "bb928da8-6ac0-4d15-cf12-4903d73708fc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"41.1 ms ± 9.08 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"41.1 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"metadata": {
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "54f8e05e-0cb2-4e4a-cacd-27a309a3be8b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"6.58 ms ± 582 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"6.58 ms ± 143 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"metadata": {
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "415e959e-b648-4f1e-f05e-8b8444e74bee",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "92b634f8-43f8-468f-87a1-bb774b64c212"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"7.2 ms ± 327 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"7.19 ms ± 294 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"metadata": {
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "05b7c696-1b97-4f18-8430-481bb8940b6b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "80c6e314-0771-470e-b090-628984ce2d85"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.38 ms ± 386 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
"2.37 ms ± 432 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit mha_pytorch(embeddings)"
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"6.66 ms ± 397 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 5) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class(embeddings)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -475,14 +610,8 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100"
},
"accelerator": "GPU"
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5