diff --git a/.gitignore b/.gitignore index 5a7d6d6..d83f522 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,7 @@ appendix-D/01_main-chapter-code/3.pdf appendix-E/01_main-chapter-code/loss-plot.pdf ch04/04_gqa/kv_bytes_vs_context_length.pdf -ch04/04_gqa/savings_vs_n_kv_groups.pdf +ch05/05_mla/kv_bytes_vs_context_length.pdf ch05/01_main-chapter-code/loss-plot.pdf ch05/01_main-chapter-code/temperature-plot.pdf diff --git a/README.md b/README.md index f6bd229..bea9371 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,7 @@ Several folders contain optional materials as a bonus for interested readers: - [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb) - [KV Cache](ch04/03_kv-cache) - [Grouped-Query Attention](ch04/04_gqa) + - [Multi-Head Latent Attention](ch04/05_mla) - **Chapter 5: Pretraining on unlabeled data:** - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/) - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg) diff --git a/ch04/04_gqa/README.md b/ch04/04_gqa/README.md index 100702e..d4d9174 100644 --- a/ch04/04_gqa/README.md +++ b/ch04/04_gqa/README.md @@ -2,12 +2,9 @@ This bonus material illustrates the memory savings when using Grouped-Query Attention (GQA) over regular Multi-Head Attention (MHA). - -   ## Introduction - Grouped-Query Attention (GQA) has become the new standard replacement for a more compute- and parameter-efficient alternative to Multi-Head Attention (MHA) in recent years. Note that it's not new and goes back to the 2023 [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245). And even the larger variants in the good old Llama 2 series used it. Here's a brief GQA summary. Unlike MHA, where each head also has its own set of keys and values, to reduce memory usage, GQA groups multiple heads to share the same key and value projections. @@ -28,8 +25,6 @@ While GQA is mainly a computational-efficiency workaround for MHA, ablation stud However, this assumes that the number of key-value groups is chosen carefully. However, if we set the number of key-value heads equal to the number of heads (this special case is known as multi-query attention), it will negatively affect the modeling performance. - -   ## GQA Memory Savings @@ -37,10 +32,10 @@ The memory savings are mostly reflected in the KV storage. We can compute the KV bytes ≈ batch_size × seqlen × (embed_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads -You can use the [memory_estimator.py](memory_estimator.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA: +You can use the [memory_estimator_gqa.py](memory_estimator_gqa.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA: ```bash -➜ uv run memory_estimator.py \ +➜ uv run memory_estimator_gqa.py \ --emb_dim 4096 --n_heads 32 --n_layers 32 \ --context_length 32768 --n_kv_groups 4 \ --batch_size 1 --dtype bf16 @@ -62,25 +57,15 @@ Ratio (MHA / GQA) : 4.00x Savings (GQA vs MHA): 75.00% ``` -The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes: +The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes as a function of the context length:   -GQA +GQA   -And the following plot shows how the KV cache size grows with an increasing context length: - -  - -GQA - -  - -You can reproduce these plots via `uv run plot_memory_estimates.py`. - - +You can reproduce the plot via `uv run plot_memory_estimates_gqa.py`.   ## GQA Code Examples diff --git a/ch04/04_gqa/gpt_with_kv_gqa.py b/ch04/04_gqa/gpt_with_kv_gqa.py index 1f830ed..6a38a62 100644 --- a/ch04/04_gqa/gpt_with_kv_gqa.py +++ b/ch04/04_gqa/gpt_with_kv_gqa.py @@ -4,7 +4,7 @@ # Code: https://github.com/rasbt/LLMs-from-scratch # This file collects all the relevant code that we covered thus far -# throughout Chapters 3-4. +# throughout Chapters 3-4, adapted to use Grouped-Query Attention (GQA). # This file can be run as a standalone script. import argparse @@ -83,7 +83,8 @@ class GroupedQueryAttention(nn.Module): # Shape: (b, num_heads, num_tokens, num_tokens) attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head - # Use the mask to fill attention scores + #################################################### + # causal mask num_tokens_Q = queries.shape[-2] num_tokens_K = keys.shape[-2] device = queries.device @@ -101,6 +102,7 @@ class GroupedQueryAttention(nn.Module): k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) mask = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + # Use the mask to fill attention scores attn_scores = attn_scores.masked_fill(mask, -torch.inf) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) @@ -111,7 +113,7 @@ class GroupedQueryAttention(nn.Module): context_vec = (attn_weights @ values).transpose(1, 2) # Combine heads, where self.d_out = self.num_heads * self.head_dim - context_vec = context_vec.reshape(b, num_tokens, self.d_out) + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) context_vec = self.out_proj(context_vec) # optional projection return context_vec @@ -184,7 +186,7 @@ class TransformerBlock(nn.Module): # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] #################################################### - # NEW + # KV cache-related x = self.att(x, use_cache=use_cache) #################################################### @@ -211,7 +213,7 @@ class GPTModel(nn.Module): # self.trf_blocks = nn.Sequential( # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) #################################################### - # NEW + # KV cache-related self.trf_blocks = nn.ModuleList( [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) @@ -228,8 +230,7 @@ class GPTModel(nn.Module): # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) #################################################### - # NEW - + # KV cache-related if use_cache: pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) self.current_pos += seq_len @@ -243,7 +244,7 @@ class GPTModel(nn.Module): # x = self.trf_blocks(x) #################################################### - # NEW + # KV cache-related for blk in self.trf_blocks: x = blk(x, use_cache=use_cache) #################################################### @@ -253,7 +254,7 @@ class GPTModel(nn.Module): return logits #################################################### - # NEW + # KV cache-related def reset_kv_cache(self): for blk in self.trf_blocks: blk.att.reset_cache() @@ -261,34 +262,6 @@ class GPTModel(nn.Module): #################################################### -def generate_text_simple(model, idx, max_new_tokens, context_size): - # idx is (B, T) array of indices in the current context - for _ in range(max_new_tokens): - - # Crop current context if it exceeds the supported context size - # E.g., if LLM supports only 5 tokens, and the context size is 10 - # then only the last 5 tokens are used as context - idx_cond = idx[:, -context_size:] - - # Get the predictions - with torch.no_grad(): - logits = model(idx_cond) - - # Focus only on the last time step - # (batch, n_token, vocab_size) becomes (batch, vocab_size) - logits = logits[:, -1, :] - - # Get the idx of the vocab entry with the highest logits value - idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1) - - # Append sampled index to the running sequence - idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1) - - return idx - - -#################################################### -# NEW def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True): model.eval() @@ -314,7 +287,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens, idx = torch.cat([idx, next_idx], dim=1) return idx -#################################################### def main(): @@ -324,6 +296,7 @@ def main(): parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") parser.add_argument("--n_kv_groups", type=int, default=2, help="Number of key/value groups.") parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + args = parser.parse_args() start_context = "Hello, I am" diff --git a/ch04/04_gqa/gpt_with_kv_mha.py b/ch04/04_gqa/gpt_with_kv_mha.py index 6907a69..f906d71 100644 --- a/ch04/04_gqa/gpt_with_kv_mha.py +++ b/ch04/04_gqa/gpt_with_kv_mha.py @@ -33,7 +33,7 @@ class MultiHeadAttention(nn.Module): self.dropout = nn.Dropout(dropout) #################################################### - # NEW + # KV cache-related code self.register_buffer("cache_k", None, persistent=False) self.register_buffer("cache_v", None, persistent=False) self.ptr_current_pos = 0 @@ -53,7 +53,7 @@ class MultiHeadAttention(nn.Module): queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) #################################################### - # NEW + # KV cache-related if use_cache: if self.cache_k is None: self.cache_k, self.cache_v = keys_new, values_new @@ -74,7 +74,7 @@ class MultiHeadAttention(nn.Module): attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head #################################################### - # NEW + # causal mask num_tokens_Q = queries.shape[-2] num_tokens_K = keys.shape[-2] device = queries.device @@ -107,12 +107,9 @@ class MultiHeadAttention(nn.Module): return context_vec - #################################################### - # NEW def reset_cache(self): self.cache_k, self.cache_v = None, None self.ptr_current_pos = 0 - #################################################### ##################################### @@ -177,7 +174,7 @@ class TransformerBlock(nn.Module): # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] #################################################### - # NEW + # KV cache-related x = self.att(x, use_cache=use_cache) #################################################### @@ -204,7 +201,7 @@ class GPTModel(nn.Module): # self.trf_blocks = nn.Sequential( # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) #################################################### - # NEW + # KV cache-related self.trf_blocks = nn.ModuleList( [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) @@ -221,8 +218,7 @@ class GPTModel(nn.Module): # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) #################################################### - # NEW - + # KV cache-related if use_cache: pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) self.current_pos += seq_len @@ -236,7 +232,7 @@ class GPTModel(nn.Module): # x = self.trf_blocks(x) #################################################### - # NEW + # KV cache-related for blk in self.trf_blocks: x = blk(x, use_cache=use_cache) #################################################### @@ -246,7 +242,7 @@ class GPTModel(nn.Module): return logits #################################################### - # NEW + # KV cache-related def reset_kv_cache(self): for blk in self.trf_blocks: blk.att.reset_cache() @@ -254,34 +250,6 @@ class GPTModel(nn.Module): #################################################### -def generate_text_simple(model, idx, max_new_tokens, context_size): - # idx is (B, T) array of indices in the current context - for _ in range(max_new_tokens): - - # Crop current context if it exceeds the supported context size - # E.g., if LLM supports only 5 tokens, and the context size is 10 - # then only the last 5 tokens are used as context - idx_cond = idx[:, -context_size:] - - # Get the predictions - with torch.no_grad(): - logits = model(idx_cond) - - # Focus only on the last time step - # (batch, n_token, vocab_size) becomes (batch, vocab_size) - logits = logits[:, -1, :] - - # Get the idx of the vocab entry with the highest logits value - idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1) - - # Append sampled index to the running sequence - idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1) - - return idx - - -#################################################### -# NEW def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True): model.eval() @@ -307,7 +275,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens, idx = torch.cat([idx, next_idx], dim=1) return idx -#################################################### def main(): @@ -316,6 +283,7 @@ def main(): parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + args = parser.parse_args() start_context = "Hello, I am" diff --git a/ch04/04_gqa/memory_estimator.py b/ch04/04_gqa/memory_estimator_gqa.py similarity index 100% rename from ch04/04_gqa/memory_estimator.py rename to ch04/04_gqa/memory_estimator_gqa.py diff --git a/ch04/04_gqa/plot_memory_estimates.py b/ch04/04_gqa/plot_memory_estimates.py deleted file mode 100644 index 6e874d3..0000000 --- a/ch04/04_gqa/plot_memory_estimates.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch - -# Plot KV-cache - - -import matplotlib.pyplot as plt - -# Import from ./memory_estimator.py -from memory_estimator import kv_bytes_total, DTYPE_BYTES - - -def bytes_convert(n): - gb = n / (1000 ** 3) - return f"{gb:.2f}" - - -def savings_percent(total_mha, total_gqa): - return (1.0 - (total_gqa / total_mha)) * 100.0 - - -def plot_savings_vs_nkvgroups(): - n_heads = 24 - emb_dim = 2048 - n_layers = 48 - batch_size = 1 - context_length = 8192 - dtype = "bf16" - bytes_per_elem = DTYPE_BYTES[dtype] - - total_mha = kv_bytes_total( - batch_size, - context_length, - emb_dim, - n_heads, - n_heads, - n_layers, - bytes_per_elem, - ) - - xs = [] - ys = [] - for n_kv_groups in range(1, n_heads + 1): - n_kv_heads = n_heads // n_kv_groups - total_gqa = kv_bytes_total( - batch_size, - context_length, - emb_dim, - n_heads, - n_kv_heads, - n_layers, - bytes_per_elem, - ) - xs.append(n_kv_groups) - ys.append(savings_percent(total_mha, total_gqa)) - - plt.figure() - plt.plot(xs, ys, marker="o") - plt.xlabel("n_kv_groups") - plt.ylabel("Savings vs MHA (%)") - plt.title( - "KV-cache Savings vs n_kv_groups\n" - "(n_heads=24, emb_dim=2048, n_layers=48, " - "batch=1, context=8192, dtype=bf16)", fontsize=8 - ) - plt.grid(True) - plt.tight_layout() - plt.savefig("savings_vs_n_kv_groups.pdf") - - -def plot_abs_kv_vs_context(): - n_heads = 24 - emb_dim = 2048 - n_layers = 48 - batch_size = 1 - n_kv_groups = 4 - dtype = "bf16" - bytes_per_elem = DTYPE_BYTES[dtype] - - n_kv_heads_mha = n_heads - n_kv_heads_gqa = n_heads // n_kv_groups - - context_lengths = [ - 256, 512, 1024, 2048, 4096, 8192, - 16384, 32768, 65536, 131072 - ] - - xs = [] - mha_gb = [] - gqa_gb = [] - savings_pct = None - - for L in context_lengths: - total_mha = kv_bytes_total( - batch_size, L, emb_dim, n_heads, - n_kv_heads_mha, n_layers, bytes_per_elem - ) - total_gqa = kv_bytes_total( - batch_size, L, emb_dim, n_heads, - n_kv_heads_gqa, n_layers, bytes_per_elem - ) - xs.append(L) - mha_gb.append(float(bytes_convert(total_mha))) - gqa_gb.append(float(bytes_convert(total_gqa))) - if savings_pct is None: - savings_pct = savings_percent(total_mha, total_gqa) - - plt.figure() - plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)") - plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})") - plt.xscale("log") - plt.xlabel("context_length (log scale)") - plt.ylabel("Total KV cache (GB)") - plt.title( - "KV-cache vs Context Length\n" - "(n_heads=24, emb_dim=2048, n_layers=48, " - "batch=1, n_kv_groups=4, dtype=bf16)", fontsize=8 - ) - plt.grid(True, which="both") - plt.legend() - plt.tight_layout() - plt.savefig("kv_bytes_vs_context_length.pdf") - print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.") - print(f"Example (context_length={context_lengths[-1]}): " - f"MHA={bytes_convert(total_mha)} GB, " - f"GQA={bytes_convert(total_gqa)} GB") - - -if __name__ == "__main__": - plot_savings_vs_nkvgroups() - plot_abs_kv_vs_context() diff --git a/ch04/04_gqa/plot_memory_estimates_gqa.py b/ch04/04_gqa/plot_memory_estimates_gqa.py new file mode 100644 index 0000000..f114180 --- /dev/null +++ b/ch04/04_gqa/plot_memory_estimates_gqa.py @@ -0,0 +1,81 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# Plot KV-cache vs context length for different n_kv_groups + +import matplotlib.pyplot as plt + +# Import from ./memory_estimator.py +from memory_estimator_gqa import kv_bytes_total, DTYPE_BYTES + + +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:.2f}" + + +def savings_percent(total_mha, total_gqa): + return (1.0 - (total_gqa / total_mha)) * 100.0 + + +def plot_abs_kv_vs_context_multi_groups(): + n_heads = 24 + emb_dim = 2048 + n_layers = 48 + batch_size = 1 + dtype = "bf16" + bytes_per_elem = DTYPE_BYTES[dtype] + + # x-axis (log scale) + context_lengths = [ + 256, 512, 1024, 2048, 4096, 8192, + 16384, 32768, 65536, 131072 + ] + + mha_gb = [] + for L in context_lengths: + total_mha = kv_bytes_total( + batch_size, L, emb_dim, n_heads, + n_heads, # MHA: n_kv_heads = n_heads + n_layers, bytes_per_elem + ) + mha_gb.append(float(bytes_convert(total_mha))) + + plt.figure() + plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)") + + # GQA curves for selected n_kv_groups + groups_list = [4, 8, 12, 24] + for g in groups_list: + n_kv_heads = n_heads // g + gqa_gb = [] + for L in context_lengths: + total_gqa = kv_bytes_total( + batch_size, L, emb_dim, n_heads, + n_kv_heads, n_layers, bytes_per_elem + ) + gqa_gb.append(float(bytes_convert(total_gqa))) + + # Compression rate relative to MHA + comp = (n_heads / n_kv_heads) if n_kv_heads > 0 else float("inf") + plt.plot(context_lengths, gqa_gb, marker="o", + label=f"GQA (n_kv_groups={g}, {comp:,.1f}× compression)") + + plt.xscale("log") + plt.xlabel("context_length (log scale)") + plt.ylabel("Total KV cache (GB)") + plt.title( + "KV-cache vs Context Length — MHA vs GQA (multi-group)\n" + "(n_heads=24, emb_dim=2048, n_layers=48, batch=1, dtype=bf16)", + fontsize=8 + ) + plt.grid(True, which="both") + plt.legend() + plt.tight_layout() + plt.savefig("kv_bytes_vs_context_length.pdf") + + +if __name__ == "__main__": + plot_abs_kv_vs_context_multi_groups() diff --git a/ch04/05_mla/README.md b/ch04/05_mla/README.md new file mode 100644 index 0000000..25d0e19 --- /dev/null +++ b/ch04/05_mla/README.md @@ -0,0 +1,142 @@ +# Multi-Head Latent Attention (MLA) + +This bonus material illustrates the memory savings when using Multi-Head Latent Attention (MLA) over regular Multi-Head Attention (MHA). + +  +## Introduction + +In [../04_gqa](../04_gqa), we discussed Grouped-Query Attention (GQA) as a computational-efficiency workaround for MHA. And ablation studies (such as those in the[ original GQA paper](https://arxiv.org/abs/2305.13245) and the [Llama 2 paper](https://arxiv.org/abs/2307.09288)) show it performs comparably to standard MHA in terms of LLM modeling performance. + +Now, Multi-Head Latent Attention (MLA), which is used in [DeepSeek V2, V3, and R1](https://arxiv.org/abs/2412.19437), offers a different memory-saving strategy that also pairs particularly well with KV caching. Instead of sharing key and value heads like GQA, MLA compresses the key and value tensors into a lower-dimensional space before storing them in the KV cache. + +At inference time, these compressed tensors are projected back to their original size before being used, as shown in the figure below. This adds an extra matrix multiplication but reduces memory usage. + +  + +![MLA](https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/mla-memory/1.webp) + +  + +(As a side note, the queries are also compressed, but only during training, not inference.) + +By the way, as mentioned earlier, MLA is not new in DeepSeek V3, as its [DeepSeek V2 predecessor](https://arxiv.org/abs/2405.04434) also used (and even introduced) it. Also, the V2 paper contains a few interesting ablation studies that may explain why the DeepSeek team chose MLA over GQA (see the figure below). + +  + +GQA + +  + +As shown in the figure above, GQA appears to perform worse than MHA, whereas MLA offers better modeling performance than MHA, which is likely why the DeepSeek team chose MLA over GQA. (It would have been interesting to see the "KV Cache per Token" savings comparison between MLA and GQA as well!) + +To summarize this section, before we move on to the next architecture component, MLA is a clever trick to reduce KV cache memory use while even slightly outperforming MHA in terms of modeling performance. + +  +## MLA Memory Savings + +The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula: + +bytes ≈ batch_size × seqlen × n_layers × latent_dim × bytes_per_elem + +In contrast, MHA KV cache memory is computed as follows: + +bytes ≈ batch_size × seqlen × n_layers × embed_dim × 2 (K,V) × bytes_per_elem + +This means, in MLA, we reduce "embed_dim × 2 (K,V)" to "latent_dim", since we only stored the compressed latent representation instead of the full key and value vectors as shown in the earlier figure above. + + + +You can use the [memory_estimator_mla.py](memory_estimator_mla.py) script in this folder to apply this for different model configs to see how much memory you can save by using MLA over MHA: + +```bash +➜ uv run memory_estimator_mla.py \ + --context_length 8192 \ + --emb_dim 2048 \ + --n_heads 24 \ + --n_layers 48 \ + --n_kv_groups 4 \ + --batch_size 1 \ + --dtype bf16 \ + --latent_dim 1024 +==== Config ==== +context_length : 8192 +emb_dim : 2048 +n_heads : 24 +n_layers : 48 +n_kv_groups : 4 +latent_dim : 1024 +batch_size : 1 +dtype : bf16 (2 Bytes/elem) +head_dim : 86 +GQA n_kv_heads : 6 + +==== KV-cache totals across all layers ==== +MHA total KV cache : 3.25 GB +GQA total KV cache : 0.81 GB +MLA total KV cache : 0.81 GB +Ratio (MHA / GQA) : 4.00x +Savings (GQA vs MHA): 75.00% +Ratio (MHA / MLA) : 4.03x +Savings (MLA vs MHA): 75.19% +``` + +Note that the compression above (`--emb_dim 2048 -> latent_dim 1024`) to achieve a similar saving as for GQA. In practice, the compression is a hyperparameter that needs to be carefully investigated, as choosing `latent_dim` to be too small can have negative impact on the modeling performance (similar to choosing too many `n_kv_groups` in GQA). + +The savings when using MLA over MHA are further shown in the plot below for different `latent_dim` values as a function of the context length: + +  + +GQA + +  + +You can reproduce the plot via `uv run plot_memory_estimates_mla.py`. + + + +  +## MLA Code Examples + +The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_mla.py](gpt_with_kv_mla.py) scripts in this folder provide hands-on examples for comparing the MHA and MLA memory usage in the context of a GPT model implementation. + +Here, the MLA code is inspired by the [https://huggingface.co/bird-of-paradise/deepseek-mla](https://huggingface.co/bird-of-paradise/deepseek-mla) implementation. + +Note that MLA can also be used in combination with GQA, but for simplicity, I this is not done here. (Currently, I am also not aware of a prominent LLM doing this.) + +Also note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it. + +Lastly, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache) so the memory savings are more pronounced. + +```bash +uv run gpt_with_kv_mha.py \ +--max_new_tokens 32768 \ +--n_heads 24 \ +--n_layers 12 \ +--emb_dim 768 + +... + +Time: 453.81 sec +72 tokens/sec +Max memory allocated: 1.54 GB +``` + +```bash +uv run gpt_with_kv_mla.py \ +--max_new_tokens 32768 \ +--n_heads 24 \ +--n_layers 12 \ +--emb_dim 768 \ +--latent_dim 192 # 4x compression + +... + +Time: 487.21 sec +67 tokens/sec +Max memory allocated: 0.68 GB +``` + +The reason why we are not seeing such a big saving as in the plots above is 2-fold: + +1. I use a smaller configuration to have the model finish the generation in a reasonable time. +2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully-connected layers in the model take up most of the memory (but this is a topic for a separate analysis). diff --git a/ch04/05_mla/gpt_with_kv_mha.py b/ch04/05_mla/gpt_with_kv_mha.py new file mode 100644 index 0000000..92e06f1 --- /dev/null +++ b/ch04/05_mla/gpt_with_kv_mha.py @@ -0,0 +1,344 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4. +# This file can be run as a standalone script. + +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + + #################################################### + # KV cache-related code + self.register_buffer("cache_k", None, persistent=False) + self.register_buffer("cache_v", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def forward(self, x, use_cache=False): + b, num_tokens, d_in = x.shape + + keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) + values_new = self.W_value(x) + queries = self.W_query(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) + values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + #################################################### + # KV cache-related + if use_cache: + if self.cache_k is None: + self.cache_k, self.cache_v = keys_new, values_new + else: + self.cache_k = torch.cat([self.cache_k, keys_new], dim=1) + self.cache_v = torch.cat([self.cache_v, values_new], dim=1) + keys, values = self.cache_k, self.cache_v + else: + keys, values = keys_new, values_new + #################################################### + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + #################################################### + # causal mask + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + if use_cache: + q_positions = torch.arange( + self.ptr_current_pos, + self.ptr_current_pos + num_tokens_Q, + device=device, + dtype=torch.long, + ) + self.ptr_current_pos += num_tokens_Q + else: + q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long) + self.ptr_current_pos = 0 + k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) + mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + def reset_cache(self): + self.cache_k, self.cache_v = None, None + self.ptr_current_pos = 0 + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"]) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + #################################################### + # KV cache-related + self.trf_blocks = nn.ModuleList( + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + logits = model(idx[:, -ctx_len:], use_cache=True) + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) + + return idx + + +def main(): + parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.") + parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + + args = parser.parse_args() + + start_context = "Hello, I am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/05_mla/gpt_with_kv_mla.py b/ch04/05_mla/gpt_with_kv_mla.py new file mode 100644 index 0000000..6e9c388 --- /dev/null +++ b/ch04/05_mla/gpt_with_kv_mla.py @@ -0,0 +1,355 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4, adapted to use Multi-Head Latent Attention (MLA). +# This file can be run as a standalone script. + +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + + +##################################### +# Multi-Head Latent Attention +##################################### +# The MLA code below is inspired by +# https://huggingface.co/bird-of-paradise/deepseek-mla + + +class MultiHeadLatentAttention(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, + qkv_bias=False, latent_dim=None): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads + self.latent_dim = latent_dim if latent_dim is not None else max(16, d_out // 8) + + # Projections + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # per-head Q + self.W_DKV = nn.Linear(d_in, self.latent_dim, bias=qkv_bias) # down to latent C + self.W_UK = nn.Linear(self.latent_dim, d_out, bias=qkv_bias) # latent -> per-head K + self.W_UV = nn.Linear(self.latent_dim, d_out, bias=qkv_bias) # latent -> per-head V + + self.out_proj = nn.Linear(d_out, d_out) + self.dropout = nn.Dropout(dropout) + + #################################################### + # Latent-KV cache + self.register_buffer("cache_c_kv", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def reset_cache(self): + self.cache_c_kv = None + self.ptr_current_pos = 0 + + @staticmethod + def _reshape_to_heads(x, num_heads, head_dim): + # (b, T, d_out) -> (b, num_heads, T, head_dim) + bsz, num_tokens, _ = x.shape + return x.view(bsz, num_tokens, num_heads, head_dim).transpose(1, 2).contiguous() + + def forward(self, x, use_cache=False): + b, num_tokens, _ = x.shape + num_heads = self.num_heads + head_dim = self.head_dim + + # 1) Project to queries (per-token, per-head) and new latent chunk + queries_all = self.W_query(x) # (b, T, d_out) + latent_new = self.W_DKV(x) # (b, T, latent_dim) + + # 2) Update latent cache and choose latent sequence to up-project + if use_cache: + if self.cache_c_kv is None: + latent_total = latent_new + else: + latent_total = torch.cat([self.cache_c_kv, latent_new], dim=1) + self.cache_c_kv = latent_total + else: + latent_total = latent_new + + # 3) Up-project latent to per-head keys/values (then split into heads) + keys_all = self.W_UK(latent_total) # (b, T_k_total, d_out) + values_all = self.W_UV(latent_total) # (b, T_k_total, d_out) + + # 4) Reshape to heads + queries = self._reshape_to_heads(queries_all, num_heads, head_dim) + keys = self._reshape_to_heads(keys_all, num_heads, head_dim) + values = self._reshape_to_heads(values_all, num_heads, head_dim) + + # 5) Scaled dot-product attention with causal mask + attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) + + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + if use_cache: + q_positions = torch.arange( + self.ptr_current_pos, + self.ptr_current_pos + num_tokens_Q, + device=device, + dtype=torch.long, + ) + self.ptr_current_pos += num_tokens_Q + else: + q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long) + self.ptr_current_pos = 0 + k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) + mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadLatentAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"], + latent_dim=cfg["latent_dim"]) + + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + #################################################### + # KV cache-related + self.trf_blocks = nn.ModuleList( + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + logits = model(idx[:, -ctx_len:], use_cache=True) + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) + + return idx + + +def main(): + parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.") + parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + parser.add_argument("--latent_dim", type=int, default=None, + help="Latent dim for MLA (default: d_out//8)") + + args = parser.parse_args() + + start_context = "Hello, I am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + "latent_dim": args.latent_dim, + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/05_mla/kv_bytes_vs_context_length.pdf b/ch04/05_mla/kv_bytes_vs_context_length.pdf new file mode 100644 index 0000000..7fb59e4 Binary files /dev/null and b/ch04/05_mla/kv_bytes_vs_context_length.pdf differ diff --git a/ch04/05_mla/memory_estimator_mla.py b/ch04/05_mla/memory_estimator_mla.py new file mode 100644 index 0000000..f9ab9f5 --- /dev/null +++ b/ch04/05_mla/memory_estimator_mla.py @@ -0,0 +1,123 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# KV-cache memory estimator for MHA vs GQA vs MLA + +import argparse +import math + +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:,.2f} GB" + + +def kv_bytes_total(batch, context_length, emb_dim, n_heads, + n_kv_heads, n_layers, bytes_per_elem): + # Generic KV-cache: per-head dim is embed_dim / n_heads, times 2 for K and V + head_dim = math.ceil(emb_dim / n_heads) + per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem + return per_layer * n_layers + + +def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem): + # Simple MLA (per-token compressed latent) + # bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem + return batch * context_length * n_layers * latent_dim * bytes_per_elem + + +def main(): + p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA vs MLA") + p.add_argument("--context_length", default=1024, type=int) + p.add_argument("--emb_dim", required=True, type=int) + p.add_argument("--n_heads", required=True, type=int) + p.add_argument("--n_layers", required=True, type=int) + p.add_argument("--n_kv_groups", required=True, type=int) + p.add_argument("--latent_dim", required=True, type=int, help="MLA per-token latent dimension") + p.add_argument("--batch_size", default=1, type=int) + p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="fp16") + args = p.parse_args() + + cfg = { + "context_length": args.context_length, + "emb_dim": args.emb_dim, + "n_heads": args.n_heads, + "n_layers": args.n_layers, + "n_kv_groups": args.n_kv_groups, + "latent_dim": args.latent_dim, + } + + if cfg["n_heads"] % cfg["n_kv_groups"] != 0: + raise ValueError("n_kv_groups must divide n_heads exactly.") + + bytes_per_elem = DTYPE_BYTES[args.dtype] + head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"]) + + n_kv_heads_mha = cfg["n_heads"] + n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"] + + total_mha = kv_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["emb_dim"], + cfg["n_heads"], + n_kv_heads_mha, + cfg["n_layers"], + bytes_per_elem, + ) + + total_gqa = kv_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["emb_dim"], + cfg["n_heads"], + n_kv_heads_gqa, + cfg["n_layers"], + bytes_per_elem, + ) + + total_mla = mla_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["n_layers"], + cfg["latent_dim"], + bytes_per_elem, + ) + + ratio = total_mha / total_gqa if total_gqa != 0 else float("inf") + savings = 1 - (total_gqa / total_mha) if total_mha != 0 else 0.0 + + ratio_mha_mla = total_mha / total_mla if total_mla != 0 else float("inf") + savings_mla = 1 - (total_mla / total_mha) if total_mha != 0 else 0.0 + + print("==== Config ====") + for k, v in cfg.items(): + print(f"{k:17}: {v}") + print(f"batch_size : {args.batch_size}") + print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)") + print(f"head_dim : {head_dim}") + print(f"GQA n_kv_heads : {n_kv_heads_gqa}") + print() + + print("==== KV-cache totals across all layers ====") + print(f"MHA total KV cache : {bytes_convert(total_mha)}") + print(f"GQA total KV cache : {bytes_convert(total_gqa)}") + print(f"MLA total KV cache : {bytes_convert(total_mla)}") + print(f"Ratio (MHA / GQA) : {ratio:,.2f}x") + print(f"Savings (GQA vs MHA): {savings*100:,.2f}%") + print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x") + print(f"Savings (MLA vs MHA): {savings_mla*100:,.2f}%") + + +if __name__ == "__main__": + main() diff --git a/ch04/05_mla/plot_memory_estimates_mla.py b/ch04/05_mla/plot_memory_estimates_mla.py new file mode 100644 index 0000000..e4c4208 --- /dev/null +++ b/ch04/05_mla/plot_memory_estimates_mla.py @@ -0,0 +1,90 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import matplotlib.pyplot as plt + +# Bytes per element +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_to_gb(n_bytes): + return n_bytes / (1000. ** 3) + + +def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads, + n_layers, bytes_per_elem): + head_dim = emb_dim / n_heads + per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem + return per_layer * n_layers + + +def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem): + return batch * context_length * n_layers * latent_dim * bytes_per_elem + + +def plot_abs_kv_vs_context_multiple(): + n_heads = 24 + emb_dim = 2048 + n_layers = 48 + batch_size = 1 + dtype = "bf16" + bytes_per_elem = DTYPE_BYTES[dtype] + + context_lengths = [ + 256, 512, 1024, 2048, 4096, 8192, + 16384, 32768, 65536, 131072 + ] + + mha_gb = [] + for L in context_lengths: + total_mha = kv_bytes_total_mha( + batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem + ) + mha_gb.append(bytes_to_gb(total_mha)) + + latent_dims = [1024, 512, 256, 64] + plt.figure() + plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)") + + L_ref = context_lengths[-1] + total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem) + + for latent_dim in latent_dims: + mla_gb = [] + for L in context_lengths: + total_mla = kv_bytes_total_mla( + batch_size, L, n_layers, latent_dim, bytes_per_elem + ) + mla_gb.append(bytes_to_gb(total_mla)) + + total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem) + comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf") + + plt.plot(context_lengths, mla_gb, marker="o", + label=f"MLA (latent_dim={latent_dim}, {comp:,.1f}× compression)") + + plt.xscale("log") + plt.xlabel("context_length (log scale)") + plt.ylabel("Total KV cache (GB)") + plt.title( + "KV-cache vs Context Length — MHA vs MLA\n" + f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, " + f"batch={batch_size}, dtype={dtype})", + fontsize=8 + ) + plt.grid(True, which="both") + plt.legend() + plt.tight_layout() + plt.savefig("kv_bytes_vs_context_length.pdf") + + +if __name__ == "__main__": + plot_abs_kv_vs_context_multiple() diff --git a/ch04/README.md b/ch04/README.md index a02eef5..0a95ce1 100644 --- a/ch04/README.md +++ b/ch04/README.md @@ -11,6 +11,8 @@ - [02_performance-analysis](02_performance-analysis) contains optional code analyzing the performance of the GPT model(s) implemented in the main chapter - [03_kv-cache](03_kv-cache) implements a KV cache to speed up the text generation during inference - [ch05/07_gpt_to_llama](../ch05/07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI (it might be interesting to look at alternative architectures after completing chapter 4, but you can also save that for after reading chapter 5) +- [04_gqa](04_gqa) contains an introduction to Grouped-Query Attention (GQA), which is used by most modern LLMs (Llama 4, gpt-oss, Qwen3, Gemma 3, and many more) as alternative to regular Multi-Head Attention (MHA) +- [05_mla](05_mla) contains an introduction to Multi-Head Latent Attention (MLA), which is used by DeepSeek V3, as alternative to regular Multi-Head Attention (MHA)