diff --git a/.gitignore b/.gitignore
index 8105822..5a7d6d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,9 @@ appendix-D/01_main-chapter-code/3.pdf
appendix-E/01_main-chapter-code/loss-plot.pdf
+ch04/04_gqa/kv_bytes_vs_context_length.pdf
+ch04/04_gqa/savings_vs_n_kv_groups.pdf
+
ch05/01_main-chapter-code/loss-plot.pdf
ch05/01_main-chapter-code/temperature-plot.pdf
ch05/01_main-chapter-code/the-verdict.txt
diff --git a/README.md b/README.md
index 2ad5e7a..f6bd229 100644
--- a/README.md
+++ b/README.md
@@ -168,6 +168,7 @@ Several folders contain optional materials as a bonus for interested readers:
- **Chapter 4: Implementing a GPT model from scratch**
- [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb)
- [KV Cache](ch04/03_kv-cache)
+ - [Grouped-Query Attention](ch04/04_gqa)
- **Chapter 5: Pretraining on unlabeled data:**
- [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
- [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)
diff --git a/ch04/04_gqa/README.md b/ch04/04_gqa/README.md
new file mode 100644
index 0000000..100702e
--- /dev/null
+++ b/ch04/04_gqa/README.md
@@ -0,0 +1,126 @@
+# Grouped-Query Attention (GQA)
+
+This bonus material illustrates the memory savings when using Grouped-Query Attention (GQA) over regular Multi-Head Attention (MHA).
+
+
+
+
+## Introduction
+
+
+Grouped-Query Attention (GQA) has become the new standard replacement for a more compute- and parameter-efficient alternative to Multi-Head Attention (MHA) in recent years. Note that it's not new and goes back to the 2023 [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245). And even the larger variants in the good old Llama 2 series used it.
+
+Here's a brief GQA summary. Unlike MHA, where each head also has its own set of keys and values, to reduce memory usage, GQA groups multiple heads to share the same key and value projections.
+
+For example, as further illustrated in the figure below, if there are 3 key-value groups and 6 attention heads, then heads 1 and 2 share one set of keys and values, while heads 3 and 4, as well as heads 5 and 6, share another, respectively.
+
+
+
+
+
+
+
+This sharing of keys and values reduces the total number of key and value computations, which leads to lower memory usage and improved efficiency.
+
+So, to summarize, the core idea behind GQA is to reduce the number of key and value heads by sharing them across multiple query heads. This (1) lowers the model's parameter count and (2) reduces the memory bandwidth usage for key and value tensors during inference since fewer keys and values need to be stored and retrieved from the KV cache.
+
+While GQA is mainly a computational-efficiency workaround for MHA, ablation studies (such as those in the [original GQA paper](https://arxiv.org/abs/2305.13245) and the [Llama 2 paper](https://arxiv.org/abs/2307.09288)) show it performs comparably to standard MHA in terms of LLM modeling performance.
+
+However, this assumes that the number of key-value groups is chosen carefully. However, if we set the number of key-value heads equal to the number of heads (this special case is known as multi-query attention), it will negatively affect the modeling performance.
+
+
+
+
+## GQA Memory Savings
+
+The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula:
+
+bytes ≈ batch_size × seqlen × (embed_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads
+
+You can use the [memory_estimator.py](memory_estimator.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA:
+
+```bash
+➜ uv run memory_estimator.py \
+ --emb_dim 4096 --n_heads 32 --n_layers 32 \
+ --context_length 32768 --n_kv_groups 4 \
+ --batch_size 1 --dtype bf16
+==== Config ====
+context_length : 32768
+emb_dim : 4096
+n_heads : 32
+n_layers : 32
+n_kv_groups : 4
+batch_size : 1
+dtype : bf16 (2 Bytes/elem)
+head_dim : 128
+GQA n_kv_heads : 8
+
+==== KV-cache totals across all layers ====
+MHA total KV cache : 17.18 GB
+GQA total KV cache : 4.29 GB
+Ratio (MHA / GQA) : 4.00x
+Savings (GQA vs MHA): 75.00%
+```
+
+The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes:
+
+
+
+
+
+
+
+And the following plot shows how the KV cache size grows with an increasing context length:
+
+
+
+
+
+
+
+You can reproduce these plots via `uv run plot_memory_estimates.py`.
+
+
+
+
+## GQA Code Examples
+
+The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_gqa.py](gpt_with_kv_gqa.py) scripts in this folder provide hands-on examples for comparing the MHA and GQA memory usage in the context of a GPT model implementation.
+
+Note that GQA is also used in the [Llama 3](../../ch05/07_gpt_to_llama), [Gemma 3](../../ch05/12_gemma3), and [Qwen3](../../ch05/11_qwen3) bonus materials. However, for simplicity, the code scripts in this folder modify the GPT architecture, which traditionally didn't use GQA.
+
+Note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.
+
+Also, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache) so the memory savings are more pronounced.
+
+```bash
+uv run gpt_with_kv_mha.py \
+--max_new_tokens 32768 \
+--n_heads 24 \
+--n_layers 12
+
+...
+
+Time: 453.81 sec
+72 tokens/sec
+Max memory allocated: 1.54 GB
+```
+
+```bash
+uv run gpt_with_kv_gqa.py \
+--max_new_tokens 32768 \
+--n_heads 24 \
+--n_layers 12 \
+--n_kv_groups 4
+
+...
+
+Time: 516.33 sec
+63 tokens/sec
+Max memory allocated: 0.63 GB
+```
+
+The reason why we are not seeing such a big saving as in the plots above is 2-fold:
+
+1. I use a smaller configuration to have the model finish the generation in a reasonable time.
+2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully-connected layers in the model take up most of the memory (but this is a topic for a separate analysis).
diff --git a/ch04/04_gqa/gpt_with_kv_gqa.py b/ch04/04_gqa/gpt_with_kv_gqa.py
new file mode 100644
index 0000000..1f830ed
--- /dev/null
+++ b/ch04/04_gqa/gpt_with_kv_gqa.py
@@ -0,0 +1,385 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+# This file can be run as a standalone script.
+
+import argparse
+import time
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+#####################################
+# NEW: GQA instead of MHA
+#####################################
+class GroupedQueryAttention(nn.Module):
+ def __init__(
+ self, d_in, d_out, dropout, num_heads, num_kv_groups, dtype=None, qkv_bias=False
+ ):
+ super().__init__()
+ assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+ assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
+
+ self.d_out = d_out
+ self.num_heads = num_heads
+ self.head_dim = d_out // num_heads
+
+ self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=qkv_bias, dtype=dtype)
+ self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=qkv_bias, dtype=dtype)
+ self.num_kv_groups = num_kv_groups
+ self.group_size = num_heads // num_kv_groups
+
+ self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias, dtype=dtype)
+ self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
+ self.dropout = nn.Dropout(dropout)
+
+ self.register_buffer("cache_k", None, persistent=False)
+ self.register_buffer("cache_v", None, persistent=False)
+ self.ptr_current_pos = 0
+
+ def forward(self, x, use_cache=False):
+ b, num_tokens, _ = x.shape
+
+ # Apply projections
+ queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
+ keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
+ values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
+
+ # Reshape
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
+ keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+ values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+
+ if use_cache:
+ if self.cache_k is None:
+ self.cache_k, self.cache_v = keys_new, values_new
+ else:
+ self.cache_k = torch.cat([self.cache_k, keys_new], dim=2)
+ self.cache_v = torch.cat([self.cache_v, values_new], dim=2)
+ keys_base, values_base = self.cache_k, self.cache_v
+ else:
+ keys_base, values_base = keys_new, values_new
+ if self.cache_k is not None or self.cache_v is not None:
+ self.cache_k, self.cache_v = None, None
+ self.ptr_current_pos = 0
+
+ # Expand keys and values to match the number of heads
+ # Shape: (b, num_heads, num_tokens, head_dim)
+ keys = keys_base.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
+ values = values_base.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
+ # For example, before repeat_interleave along dim=1 (query groups):
+ # [K1, K2]
+ # After repeat_interleave (each query group is repeated group_size times):
+ # [K1, K1, K2, K2]
+ # If we used regular repeat instead of repeat_interleave, we'd get:
+ # [K1, K2, K1, K2]
+
+ # Compute scaled dot-product attention (aka self-attention) with a causal mask
+ # Shape: (b, num_heads, num_tokens, num_tokens)
+ attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
+
+ # Use the mask to fill attention scores
+ num_tokens_Q = queries.shape[-2]
+ num_tokens_K = keys.shape[-2]
+ device = queries.device
+ if use_cache:
+ q_positions = torch.arange(
+ self.ptr_current_pos,
+ self.ptr_current_pos + num_tokens_Q,
+ device=device,
+ dtype=torch.long,
+ )
+ self.ptr_current_pos += num_tokens_Q
+ else:
+ q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long)
+ self.ptr_current_pos = 0
+ k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long)
+ mask = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0)
+
+ attn_scores = attn_scores.masked_fill(mask, -torch.inf)
+
+ attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+ assert keys.shape[-1] == self.head_dim
+ attn_weights = self.dropout(attn_weights)
+
+ # Shape: (b, num_tokens, num_heads, head_dim)
+ context_vec = (attn_weights @ values).transpose(1, 2)
+
+ # Combine heads, where self.d_out = self.num_heads * self.head_dim
+ context_vec = context_vec.reshape(b, num_tokens, self.d_out)
+ context_vec = self.out_proj(context_vec) # optional projection
+
+ return context_vec
+
+ def reset_cache(self):
+ self.cache_k, self.cache_v = None, None
+ self.ptr_current_pos = 0
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+ def __init__(self, emb_dim):
+ super().__init__()
+ self.eps = 1e-5
+ self.scale = nn.Parameter(torch.ones(emb_dim))
+ self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
+ norm_x = (x - mean) / torch.sqrt(var + self.eps)
+ return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(
+ torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+ (x + 0.044715 * torch.pow(x, 3))
+ ))
+
+
+class FeedForward(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+ GELU(),
+ nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.att = GroupedQueryAttention(
+ d_in=cfg["emb_dim"],
+ d_out=cfg["emb_dim"],
+ num_heads=cfg["n_heads"],
+ num_kv_groups=cfg["n_kv_groups"],
+ dropout=cfg["drop_rate"],
+ qkv_bias=cfg["qkv_bias"])
+ self.ff = FeedForward(cfg)
+ self.norm1 = LayerNorm(cfg["emb_dim"])
+ self.norm2 = LayerNorm(cfg["emb_dim"])
+ self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+ def forward(self, x, use_cache=False):
+ # Shortcut connection for attention block
+ shortcut = x
+ x = self.norm1(x)
+
+ # x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
+ ####################################################
+ # NEW
+ x = self.att(x, use_cache=use_cache)
+ ####################################################
+
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ # Shortcut connection for feed-forward block
+ shortcut = x
+ x = self.norm2(x)
+ x = self.ff(x)
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ return x
+
+
+class GPTModel(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+ self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+ self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+ # self.trf_blocks = nn.Sequential(
+ # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+ ####################################################
+ # NEW
+ self.trf_blocks = nn.ModuleList(
+ [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+ self.current_pos = 0
+ ####################################################
+
+ self.final_norm = LayerNorm(cfg["emb_dim"])
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+ def forward(self, in_idx, use_cache=False):
+ batch_size, seq_len = in_idx.shape
+ tok_embeds = self.tok_emb(in_idx)
+
+ # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+
+ ####################################################
+ # NEW
+
+ if use_cache:
+ pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+ self.current_pos += seq_len
+ else:
+ pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
+ pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
+ ####################################################
+
+ x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_emb(x)
+
+ # x = self.trf_blocks(x)
+ ####################################################
+ # NEW
+ for blk in self.trf_blocks:
+ x = blk(x, use_cache=use_cache)
+ ####################################################
+
+ x = self.final_norm(x)
+ logits = self.out_head(x)
+ return logits
+
+ ####################################################
+ # NEW
+ def reset_kv_cache(self):
+ for blk in self.trf_blocks:
+ blk.att.reset_cache()
+ self.current_pos = 0
+ ####################################################
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+ # idx is (B, T) array of indices in the current context
+ for _ in range(max_new_tokens):
+
+ # Crop current context if it exceeds the supported context size
+ # E.g., if LLM supports only 5 tokens, and the context size is 10
+ # then only the last 5 tokens are used as context
+ idx_cond = idx[:, -context_size:]
+
+ # Get the predictions
+ with torch.no_grad():
+ logits = model(idx_cond)
+
+ # Focus only on the last time step
+ # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+ logits = logits[:, -1, :]
+
+ # Get the idx of the vocab entry with the highest logits value
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
+
+ # Append sampled index to the running sequence
+ idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
+
+ return idx
+
+
+####################################################
+# NEW
+def generate_text_simple_cached(model, idx, max_new_tokens,
+ context_size=None, use_cache=True):
+ model.eval()
+ ctx_len = context_size or model.pos_emb.num_embeddings
+
+ with torch.no_grad():
+ if use_cache:
+ # Init cache with full prompt
+ model.reset_kv_cache()
+ logits = model(idx[:, -ctx_len:], use_cache=True)
+
+ for _ in range(max_new_tokens):
+ # a) pick the token with the highest log-probability (greedy sampling)
+ next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+ # b) append it to the running sequence
+ idx = torch.cat([idx, next_idx], dim=1)
+ # c) feed model only the new token
+ logits = model(next_idx, use_cache=True)
+ else:
+ for _ in range(max_new_tokens):
+ logits = model(idx[:, -ctx_len:], use_cache=False)
+ next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+ idx = torch.cat([idx, next_idx], dim=1)
+
+ return idx
+####################################################
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Run GPT with grouped-query attention.")
+ parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
+ parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
+ parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
+ parser.add_argument("--n_kv_groups", type=int, default=2, help="Number of key/value groups.")
+ parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
+ args = parser.parse_args()
+
+ start_context = "Hello, I am"
+ tokenizer = tiktoken.get_encoding("gpt2")
+ encoded = tokenizer.encode(start_context)
+
+ GPT_CONFIG_124M = {
+ "vocab_size": 50257, # Vocabulary size
+ "context_length": args.max_new_tokens + len(encoded),
+ "emb_dim": args.emb_dim, # Embedding dimension
+ "n_heads": args.n_heads, # Number of attention heads
+ "n_layers": args.n_layers, # Number of layers
+ "drop_rate": 0.0, # Dropout rate
+ "qkv_bias": False, # Query-Key-Value bias
+ "n_kv_groups": args.n_kv_groups
+ }
+ torch.manual_seed(123)
+ model = GPTModel(GPT_CONFIG_124M)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device, dtype=torch.bfloat16)
+ model.eval() # disable dropout
+
+ encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+ print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+ print("\nInput text:", start_context)
+ print("Encoded input text:", encoded)
+ print("encoded_tensor.shape:", encoded_tensor.shape)
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ start = time.time()
+
+ token_ids = generate_text_simple_cached(
+ model=model,
+ idx=encoded_tensor,
+ max_new_tokens=args.max_new_tokens,
+ )
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ total_time = time.time() - start
+
+ decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
+
+ print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+ print("\nOutput:", token_ids)
+ print("Output length:", len(token_ids[0]))
+ print("Output text:", decoded_text)
+
+ print(f"\nTime: {total_time:.2f} sec")
+ print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
+ if torch.cuda.is_available():
+ max_mem_bytes = torch.cuda.max_memory_allocated()
+ max_mem_gb = max_mem_bytes / (1024 ** 3)
+ print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ch04/04_gqa/gpt_with_kv_mha.py b/ch04/04_gqa/gpt_with_kv_mha.py
new file mode 100644
index 0000000..6907a69
--- /dev/null
+++ b/ch04/04_gqa/gpt_with_kv_mha.py
@@ -0,0 +1,376 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+# This file can be run as a standalone script.
+
+import argparse
+import time
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+ def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False):
+ super().__init__()
+ assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+ self.d_out = d_out
+ self.num_heads = num_heads
+ self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
+
+ self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
+ self.dropout = nn.Dropout(dropout)
+
+ ####################################################
+ # NEW
+ self.register_buffer("cache_k", None, persistent=False)
+ self.register_buffer("cache_v", None, persistent=False)
+ self.ptr_current_pos = 0
+ ####################################################
+
+ def forward(self, x, use_cache=False):
+ b, num_tokens, d_in = x.shape
+
+ keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
+ values_new = self.W_value(x)
+ queries = self.W_query(x)
+
+ # We implicitly split the matrix by adding a `num_heads` dimension
+ # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+ keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
+ values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+ ####################################################
+ # NEW
+ if use_cache:
+ if self.cache_k is None:
+ self.cache_k, self.cache_v = keys_new, values_new
+ else:
+ self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
+ self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
+ keys, values = self.cache_k, self.cache_v
+ else:
+ keys, values = keys_new, values_new
+ ####################################################
+
+ # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+ keys = keys.transpose(1, 2)
+ queries = queries.transpose(1, 2)
+ values = values.transpose(1, 2)
+
+ # Compute scaled dot-product attention (aka self-attention) with a causal mask
+ attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
+
+ ####################################################
+ # NEW
+ num_tokens_Q = queries.shape[-2]
+ num_tokens_K = keys.shape[-2]
+ device = queries.device
+ if use_cache:
+ q_positions = torch.arange(
+ self.ptr_current_pos,
+ self.ptr_current_pos + num_tokens_Q,
+ device=device,
+ dtype=torch.long,
+ )
+ self.ptr_current_pos += num_tokens_Q
+ else:
+ q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long)
+ self.ptr_current_pos = 0
+ k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long)
+ mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0)
+
+ # Use the mask to fill attention scores
+ attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+ attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+ attn_weights = self.dropout(attn_weights)
+
+ # Shape: (b, num_tokens, num_heads, head_dim)
+ context_vec = (attn_weights @ values).transpose(1, 2)
+
+ # Combine heads, where self.d_out = self.num_heads * self.head_dim
+ context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+ context_vec = self.out_proj(context_vec) # optional projection
+
+ return context_vec
+
+ ####################################################
+ # NEW
+ def reset_cache(self):
+ self.cache_k, self.cache_v = None, None
+ self.ptr_current_pos = 0
+ ####################################################
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+ def __init__(self, emb_dim):
+ super().__init__()
+ self.eps = 1e-5
+ self.scale = nn.Parameter(torch.ones(emb_dim))
+ self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
+ norm_x = (x - mean) / torch.sqrt(var + self.eps)
+ return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(
+ torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+ (x + 0.044715 * torch.pow(x, 3))
+ ))
+
+
+class FeedForward(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+ GELU(),
+ nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.att = MultiHeadAttention(
+ d_in=cfg["emb_dim"],
+ d_out=cfg["emb_dim"],
+ num_heads=cfg["n_heads"],
+ dropout=cfg["drop_rate"],
+ qkv_bias=cfg["qkv_bias"])
+ self.ff = FeedForward(cfg)
+ self.norm1 = LayerNorm(cfg["emb_dim"])
+ self.norm2 = LayerNorm(cfg["emb_dim"])
+ self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+ def forward(self, x, use_cache=False):
+ # Shortcut connection for attention block
+ shortcut = x
+ x = self.norm1(x)
+
+ # x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
+ ####################################################
+ # NEW
+ x = self.att(x, use_cache=use_cache)
+ ####################################################
+
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ # Shortcut connection for feed-forward block
+ shortcut = x
+ x = self.norm2(x)
+ x = self.ff(x)
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ return x
+
+
+class GPTModel(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+ self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+ self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+ # self.trf_blocks = nn.Sequential(
+ # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+ ####################################################
+ # NEW
+ self.trf_blocks = nn.ModuleList(
+ [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+ self.current_pos = 0
+ ####################################################
+
+ self.final_norm = LayerNorm(cfg["emb_dim"])
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+ def forward(self, in_idx, use_cache=False):
+ batch_size, seq_len = in_idx.shape
+ tok_embeds = self.tok_emb(in_idx)
+
+ # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+
+ ####################################################
+ # NEW
+
+ if use_cache:
+ pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+ self.current_pos += seq_len
+ else:
+ pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
+ pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
+ ####################################################
+
+ x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_emb(x)
+
+ # x = self.trf_blocks(x)
+ ####################################################
+ # NEW
+ for blk in self.trf_blocks:
+ x = blk(x, use_cache=use_cache)
+ ####################################################
+
+ x = self.final_norm(x)
+ logits = self.out_head(x)
+ return logits
+
+ ####################################################
+ # NEW
+ def reset_kv_cache(self):
+ for blk in self.trf_blocks:
+ blk.att.reset_cache()
+ self.current_pos = 0
+ ####################################################
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+ # idx is (B, T) array of indices in the current context
+ for _ in range(max_new_tokens):
+
+ # Crop current context if it exceeds the supported context size
+ # E.g., if LLM supports only 5 tokens, and the context size is 10
+ # then only the last 5 tokens are used as context
+ idx_cond = idx[:, -context_size:]
+
+ # Get the predictions
+ with torch.no_grad():
+ logits = model(idx_cond)
+
+ # Focus only on the last time step
+ # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+ logits = logits[:, -1, :]
+
+ # Get the idx of the vocab entry with the highest logits value
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
+
+ # Append sampled index to the running sequence
+ idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
+
+ return idx
+
+
+####################################################
+# NEW
+def generate_text_simple_cached(model, idx, max_new_tokens,
+ context_size=None, use_cache=True):
+ model.eval()
+ ctx_len = context_size or model.pos_emb.num_embeddings
+
+ with torch.no_grad():
+ if use_cache:
+ # Init cache with full prompt
+ model.reset_kv_cache()
+ logits = model(idx[:, -ctx_len:], use_cache=True)
+
+ for _ in range(max_new_tokens):
+ # a) pick the token with the highest log-probability (greedy sampling)
+ next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+ # b) append it to the running sequence
+ idx = torch.cat([idx, next_idx], dim=1)
+ # c) feed model only the new token
+ logits = model(next_idx, use_cache=True)
+ else:
+ for _ in range(max_new_tokens):
+ logits = model(idx[:, -ctx_len:], use_cache=False)
+ next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+ idx = torch.cat([idx, next_idx], dim=1)
+
+ return idx
+####################################################
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+ parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
+ parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
+ parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
+ parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
+ args = parser.parse_args()
+
+ start_context = "Hello, I am"
+ tokenizer = tiktoken.get_encoding("gpt2")
+ encoded = tokenizer.encode(start_context)
+
+ GPT_CONFIG_124M = {
+ "vocab_size": 50257, # Vocabulary size
+ "context_length": args.max_new_tokens + len(encoded),
+ "emb_dim": args.emb_dim, # Embedding dimension
+ "n_heads": args.n_heads, # Number of attention heads
+ "n_layers": args.n_layers, # Number of layers
+ "drop_rate": 0.0, # Dropout rate
+ "qkv_bias": False, # Query-Key-Value bias
+ }
+ torch.manual_seed(123)
+ model = GPTModel(GPT_CONFIG_124M)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device, dtype=torch.bfloat16)
+ model.eval() # disable dropout
+
+ encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+ print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+ print("\nInput text:", start_context)
+ print("Encoded input text:", encoded)
+ print("encoded_tensor.shape:", encoded_tensor.shape)
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ start = time.time()
+
+ token_ids = generate_text_simple_cached(
+ model=model,
+ idx=encoded_tensor,
+ max_new_tokens=args.max_new_tokens,
+ )
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ total_time = time.time() - start
+
+ decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
+
+ print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+ print("\nOutput:", token_ids)
+ print("Output length:", len(token_ids[0]))
+ print("Output text:", decoded_text)
+
+ print(f"\nTime: {total_time:.2f} sec")
+ print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
+ if torch.cuda.is_available():
+ max_mem_bytes = torch.cuda.max_memory_allocated()
+ max_mem_gb = max_mem_bytes / (1024 ** 3)
+ print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ch04/04_gqa/memory_estimator.py b/ch04/04_gqa/memory_estimator.py
new file mode 100644
index 0000000..1151acb
--- /dev/null
+++ b/ch04/04_gqa/memory_estimator.py
@@ -0,0 +1,98 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# KV-cache memory estimator for MHA vs GQA
+
+
+import argparse
+import math
+
+DTYPE_BYTES = {
+ "fp32": 4,
+ "bf16": 2,
+ "fp16": 2,
+ "fp8": 1,
+ "int8": 1,
+}
+
+
+def bytes_convert(n):
+ gb = n / (1000 ** 3)
+ return f"{gb:,.2f} GB"
+
+
+def kv_bytes_total(batch, context_length, emb_dim, n_heads,
+ n_kv_heads, n_layers, bytes_per_elem):
+ head_dim = math.ceil(emb_dim / n_heads)
+ per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
+ return per_layer * n_layers
+
+
+def main():
+ p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA")
+ p.add_argument("--context_length", default=1024, type=int)
+ p.add_argument("--emb_dim", required=True, type=int)
+ p.add_argument("--n_heads", required=True, type=int)
+ p.add_argument("--n_layers", required=True, type=int)
+ p.add_argument("--n_kv_groups", required=True, type=int)
+ p.add_argument("--batch_size", default=1, type=int)
+ p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="fp16")
+ args = p.parse_args()
+
+ cfg = {
+ "context_length": args.context_length,
+ "emb_dim": args.emb_dim,
+ "n_heads": args.n_heads,
+ "n_layers": args.n_layers,
+ "n_kv_groups": args.n_kv_groups,
+ }
+
+ bytes_per_elem = DTYPE_BYTES[args.dtype]
+ head_dim = cfg["emb_dim"] / cfg["n_heads"]
+
+ n_kv_heads_mha = cfg["n_heads"]
+ n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
+
+ total_mha = kv_bytes_total(
+ args.batch_size,
+ cfg["context_length"],
+ cfg["emb_dim"],
+ cfg["n_heads"],
+ n_kv_heads_mha,
+ cfg["n_layers"],
+ bytes_per_elem,
+ )
+
+ total_gqa = kv_bytes_total(
+ args.batch_size,
+ cfg["context_length"],
+ cfg["emb_dim"],
+ cfg["n_heads"],
+ n_kv_heads_gqa,
+ cfg["n_layers"],
+ bytes_per_elem,
+ )
+
+ ratio = total_mha / total_gqa
+ savings = 1 - (total_gqa / total_mha)
+
+ print("==== Config ====")
+ for k, v in cfg.items():
+ print(f"{k:17}: {v}")
+ print(f"batch_size : {args.batch_size}")
+ print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)")
+ print(f"head_dim : {int(head_dim)}")
+ print(f"GQA n_kv_heads : {n_kv_heads_gqa}")
+ print()
+
+ print("==== KV-cache totals across all layers ====")
+ print(f"MHA total KV cache : {bytes_convert(total_mha)}")
+ print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
+ print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
+ print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ch04/04_gqa/plot_memory_estimates.py b/ch04/04_gqa/plot_memory_estimates.py
new file mode 100644
index 0000000..5476c1e
--- /dev/null
+++ b/ch04/04_gqa/plot_memory_estimates.py
@@ -0,0 +1,125 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# Plot KV-cache
+
+
+import matplotlib.pyplot as plt
+
+# Import from ./memory_estimator.py
+from memory_estimator import kv_bytes_total, DTYPE_BYTES
+
+
+def savings_percent(total_mha, total_gqa):
+ return (1.0 - (total_gqa / total_mha)) * 100.0
+
+
+def plot_savings_vs_nkvgroups():
+ n_heads = 24
+ emb_dim = 2048
+ n_layers = 48
+ batch_size = 1
+ context_length = 8192
+ dtype = "bf16"
+ bytes_per_elem = DTYPE_BYTES[dtype]
+
+ total_mha = kv_bytes_total(
+ batch_size,
+ context_length,
+ emb_dim,
+ n_heads,
+ n_heads,
+ n_layers,
+ bytes_per_elem,
+ )
+
+ xs = []
+ ys = []
+ for n_kv_groups in range(1, n_heads + 1):
+ n_kv_heads = n_heads // n_kv_groups
+ total_gqa = kv_bytes_total(
+ batch_size,
+ context_length,
+ emb_dim,
+ n_heads,
+ n_kv_heads,
+ n_layers,
+ bytes_per_elem,
+ )
+ xs.append(n_kv_groups)
+ ys.append(savings_percent(total_mha, total_gqa))
+
+ plt.figure()
+ plt.plot(xs, ys, marker="o")
+ plt.xlabel("n_kv_groups")
+ plt.ylabel("Savings vs MHA (%)")
+ plt.title(
+ "KV-cache Savings vs n_kv_groups\n"
+ "(n_heads=24, emb_dim=2048, n_layers=48, "
+ "batch=1, context=8192, dtype=bf16)", fontsize=8
+ )
+ plt.grid(True)
+ plt.tight_layout()
+ plt.savefig("savings_vs_n_kv_groups.pdf")
+
+
+def plot_abs_kv_vs_context():
+ n_heads = 24
+ emb_dim = 2048
+ n_layers = 48
+ batch_size = 1
+ n_kv_groups = 4
+ dtype = "bf16"
+ bytes_per_elem = DTYPE_BYTES[dtype]
+
+ n_kv_heads_mha = n_heads
+ n_kv_heads_gqa = n_heads // n_kv_groups
+
+ context_lengths = [
+ 256, 512, 1024, 2048, 4096, 8192,
+ 16384, 32768, 65536, 131072
+ ]
+
+ xs = []
+ mha_gib = []
+ gqa_gib = []
+ savings_pct = None
+
+ for L in context_lengths:
+ total_mha = kv_bytes_total(
+ batch_size, L, emb_dim, n_heads,
+ n_kv_heads_mha, n_layers, bytes_per_elem
+ )
+ total_gqa = kv_bytes_total(
+ batch_size, L, emb_dim, n_heads,
+ n_kv_heads_gqa, n_layers, bytes_per_elem
+ )
+ xs.append(L)
+ mha_gib.append(total_mha / (1024**3))
+ gqa_gib.append(total_gqa / (1024**3))
+ if savings_pct is None:
+ savings_pct = savings_percent(total_mha, total_gqa)
+
+ plt.figure()
+ plt.plot(xs, mha_gib, marker="o", label="MHA (KV total)")
+ plt.plot(xs, gqa_gib, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
+ plt.xscale("log")
+ plt.xlabel("context_length (log scale)")
+ plt.ylabel("Total KV cache (GB)")
+ plt.title(
+ "KV-cache vs Context Length\n"
+ "(n_heads=24, emb_dim=2048, n_layers=48, "
+ "batch=1, n_kv_groups=4, dtype=bf16)", fontsize=8
+ )
+ plt.grid(True, which="both")
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig("kv_bytes_vs_context_length.pdf")
+ print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
+
+
+if __name__ == "__main__":
+ plot_savings_vs_nkvgroups()
+ plot_abs_kv_vs_context()