From 87fcfd9245261f82377487bbf98118b8fbf8bca1 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 6 Mar 2024 08:30:32 -0600 Subject: [PATCH] mha variants --- README.md | 7 +- ch03/01_main-chapter-code/ch03.ipynb | 4 +- .../multihead-attention.ipynb | 4 +- .../README.md | 3 + .../ch03.py | 58 +++ .../mha-implementations.ipynb | 356 ++++++++++++++++++ ch03/README.md | 3 +- ch04/01_main-chapter-code/gpt.py | 2 +- .../01_main-chapter-code/previous_chapters.py | 2 +- ch05/02_hparam_tuning/previous_chapters.py | 2 +- 10 files changed, 431 insertions(+), 10 deletions(-) create mode 100644 ch03/02_bonus_efficient-multihead-attention/README.md create mode 100644 ch03/02_bonus_efficient-multihead-attention/ch03.py create mode 100644 ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb diff --git a/README.md b/README.md index 2440e75..c223930 100644 --- a/README.md +++ b/README.md @@ -41,12 +41,15 @@ Alternatively, you can view this and other files on GitHub at [https://github.co | Ch 6: Finetuning for Text Classification | Q2 2024 | ... | | Ch 7: Finetuning with Human Feedback | Q2 2024 | ... | | Ch 8: Using Large Language Models in Practice | Q2/3 2024 | ... | -| Appendix A: Introduction to PyTorch* | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)
- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)
- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)
- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) | +| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)
- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)
- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)
- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) | +| Appendix B: References and Further Reading | No code | | +| Appendix C: Exercises | No code | | +
> [!TIP] -> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages.) +> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages. diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb index 614ae08..734bcdf 100644 --- a/ch03/01_main-chapter-code/ch03.ipynb +++ b/ch03/01_main-chapter-code/ch03.ipynb @@ -1637,7 +1637,7 @@ "class MultiHeadAttention(nn.Module):\n", " def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n", " super().__init__()\n", - " assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n", + " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", " self.d_out = d_out\n", " self.num_heads = num_heads\n", @@ -1865,7 +1865,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/ch03/01_main-chapter-code/multihead-attention.ipynb b/ch03/01_main-chapter-code/multihead-attention.ipynb index c981b2b..2a072d3 100644 --- a/ch03/01_main-chapter-code/multihead-attention.ipynb +++ b/ch03/01_main-chapter-code/multihead-attention.ipynb @@ -243,7 +243,7 @@ "class MultiHeadAttention(nn.Module):\n", " def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n", " super().__init__()\n", - " assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n", + " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", " self.d_out = d_out\n", " self.num_heads = num_heads\n", @@ -342,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.12" } }, 
"nbformat": 4, diff --git a/ch03/02_bonus_efficient-multihead-attention/README.md b/ch03/02_bonus_efficient-multihead-attention/README.md new file mode 100644 index 0000000..e76a634 --- /dev/null +++ b/ch03/02_bonus_efficient-multihead-attention/README.md @@ -0,0 +1,3 @@ +# More Efficient Multi-Head Attention Implementations + +- [mha-implementations.ipynb](mha-implementations.ipynb) contains and compares different implementations of multi-head attention \ No newline at end of file diff --git a/ch03/02_bonus_efficient-multihead-attention/ch03.py b/ch03/02_bonus_efficient-multihead-attention/ch03.py new file mode 100644 index 0000000..f7343d7 --- /dev/null +++ b/ch03/02_bonus_efficient-multihead-attention/ch03.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) + + def forward(self, x): + b, num_tokens, d_in = x.shape + + keys = self.W_key(x) # Shape: (b, num_tokens, d_out) + queries = self.W_query(x) + values = self.W_value(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) + values = values.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, 
num_tokens, self.num_heads, self.head_dim) + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + # Original mask truncated to the number of tokens and converted to boolean + mask_bool = self.mask.bool()[:num_tokens, :num_tokens] + # Unsqueeze the mask to match dimensions + mask_unsqueezed = mask_bool.unsqueeze(0) + # Use the unsqueezed mask to fill attention scores + attn_scores.masked_fill_(mask_unsqueezed, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec \ No newline at end of file diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb new file mode 100644 index 0000000..89e1c40 --- /dev/null +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6f678e62-7bcb-4405-86ae-dce94f494303", + "metadata": {}, + "source": [ + "# Appendix D: Efficient Multi-Head Attention Implementations" + ] + }, + { + "cell_type": "markdown", + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", + "metadata": {}, + "source": [ + "## Multi-head attention implementation from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + 
"metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "torch.manual_seed(123)\n", + "\n", + "batch_size = 8\n", + "context_len = 1024\n", + "embed_dim = 768\n", + "embeddings = torch.randn((batch_size, context_len, embed_dim))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttention as Ch03_MHA\n", + "\n", + "mha_ch03 = Ch03_MHA(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ")\n", + "\n", + "out = mha_ch03(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", + "metadata": {}, + "source": [ + "## An alternative multi-head attention with combined weights" + ] + }, + { + "cell_type": "markdown", + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", + "metadata": {}, + "source": [ + "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", + "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", + "\n", + " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " \n", + "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", + "- Using `q, k, v = 
qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MultiHeadAttentionAlt(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @ keys.transpose(-2, -1)\n", + " attn_scores = 
attn_scores.masked_fill(\n", + " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", + " )\n", + " \n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", + " attn_weights = self.dropout(attn_weights)\n", + "\n", + " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", + " context_vec = attn_weights @ values\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", + " context_vec = context_vec.transpose(1, 2)\n", + "\n", + " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, d_out)\n", + " context_vec = context_vec.reshape(batch_size, num_tokens, self.num_heads * self.head_dim)\n", + "\n", + " context_vec = self.proj(context_vec)\n", + "\n", + " return context_vec\n", + "\n", + "\n", + "mha_alt = MultiHeadAttentionAlt(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ")\n", + "\n", + "out = mha_alt(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", + "metadata": {}, + "source": [ + "## Multihead attention with PyTorch's scaled dot product attention" + ] + }, + { + "cell_type": "markdown", + "id": "f78e346f-3b85-44e6-9feb-f01131381148", + "metadata": {}, + "source": [ + "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", + "metadata": {}, + "outputs": [], + "source": [ + "class MultiHeadAttentionPyTorch(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", 
+ "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + " self.d_out = d_out\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.dropout = dropout\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " q, k, v = qkv.unbind(0)\n", + "\n", + " use_dropout = 0. 
if not self.training else self.dropout\n", + " context_vec = torch.nn.functional.scaled_dot_product_attention(q, k, v, \n", + " attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", + "\n", + " return self.proj(context_vec)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch = MultiHeadAttentionPyTorch(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ")\n", + "\n", + "out = mha_pytorch(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "8877de71-f84f-4f6d-bc87-7552013b6301", + "metadata": {}, + "source": [ + "## Speed comparison" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "253 ms ± 9.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "309 ms ± 26.7 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit mha_alt(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "90.4 ms ± 719 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit mha_pytorch(embeddings)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch03/README.md b/ch03/README.md index 846044b..b781e2e 100644 --- a/ch03/README.md +++ b/ch03/README.md @@ -1,3 +1,4 @@ # Chapter 3: Coding Attention Mechanisms -- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code. \ No newline at end of file +- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code. 
+- [02_bonus_efficient-multihead-attention](02_bonus_efficient-multihead-attention) implements and compares different implementation variants of multihead-attention \ No newline at end of file diff --git a/ch04/01_main-chapter-code/gpt.py b/ch04/01_main-chapter-code/gpt.py index f6dde98..8390ddb 100644 --- a/ch04/01_main-chapter-code/gpt.py +++ b/ch04/01_main-chapter-code/gpt.py @@ -56,7 +56,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256, class MultiHeadAttention(nn.Module): def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False): super().__init__() - assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" self.d_out = d_out self.num_heads = num_heads diff --git a/ch04/01_main-chapter-code/previous_chapters.py b/ch04/01_main-chapter-code/previous_chapters.py index 926dba4..21b2edf 100644 --- a/ch04/01_main-chapter-code/previous_chapters.py +++ b/ch04/01_main-chapter-code/previous_chapters.py @@ -45,7 +45,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256, class MultiHeadAttention(nn.Module): def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False): super().__init__() - assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" self.d_out = d_out self.num_heads = num_heads diff --git a/ch05/02_hparam_tuning/previous_chapters.py b/ch05/02_hparam_tuning/previous_chapters.py index 6b1a00e..fc8f64b 100644 --- a/ch05/02_hparam_tuning/previous_chapters.py +++ b/ch05/02_hparam_tuning/previous_chapters.py @@ -56,7 +56,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256, class MultiHeadAttention(nn.Module): def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False): super().__init__() - assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + assert d_out % num_heads == 0, "d_out must be 
divisible by num_heads" self.d_out = d_out self.num_heads = num_heads