mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Multi-Head Latent Attention (#876)
* Multi-Head Latent Attention * update
This commit is contained in:
committed by
GitHub
parent
bf27ad1485
commit
9b9586688d
@@ -2,12 +2,9 @@
|
||||
|
||||
This bonus material illustrates the memory savings when using Grouped-Query Attention (GQA) over regular Multi-Head Attention (MHA).
|
||||
|
||||
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
|
||||
Grouped-Query Attention (GQA) has become the new standard replacement for a more compute- and parameter-efficient alternative to Multi-Head Attention (MHA) in recent years. Note that it's not new and goes back to the 2023 [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245). And even the larger variants in the good old Llama 2 series used it.
|
||||
|
||||
Here's a brief GQA summary. Unlike MHA, where each head also has its own set of keys and values, to reduce memory usage, GQA groups multiple heads to share the same key and value projections.
|
||||
@@ -28,8 +25,6 @@ While GQA is mainly a computational-efficiency workaround for MHA, ablation stud
|
||||
|
||||
However, this assumes that the number of key-value groups is chosen carefully. However, if we set the number of key-value heads equal to the number of heads (this special case is known as multi-query attention), it will negatively affect the modeling performance.
|
||||
|
||||
|
||||
|
||||
|
||||
## GQA Memory Savings
|
||||
|
||||
@@ -37,10 +32,10 @@ The memory savings are mostly reflected in the KV storage. We can compute the KV
|
||||
|
||||
bytes ≈ batch_size × seqlen × (embed_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads
|
||||
|
||||
You can use the [memory_estimator.py](memory_estimator.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA:
|
||||
You can use the [memory_estimator_gqa.py](memory_estimator_gqa.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA:
|
||||
|
||||
```bash
|
||||
➜ uv run memory_estimator.py \
|
||||
➜ uv run memory_estimator_gqa.py \
|
||||
--emb_dim 4096 --n_heads 32 --n_layers 32 \
|
||||
--context_length 32768 --n_kv_groups 4 \
|
||||
--batch_size 1 --dtype bf16
|
||||
@@ -62,25 +57,15 @@ Ratio (MHA / GQA) : 4.00x
|
||||
Savings (GQA vs MHA): 75.00%
|
||||
```
|
||||
|
||||
The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes:
|
||||
The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes as a function of the context length:
|
||||
|
||||
|
||||
|
||||
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/2.webp?2" alt="GQA" width="500px" />
|
||||
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/3.webp?4" alt="GQA" width="500px" />
|
||||
|
||||
|
||||
|
||||
And the following plot shows how the KV cache size grows with an increasing context length:
|
||||
|
||||
|
||||
|
||||
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/3.webp?2" alt="GQA" width="500px" />
|
||||
|
||||
|
||||
|
||||
You can reproduce these plots via `uv run plot_memory_estimates.py`.
|
||||
|
||||
|
||||
You can reproduce the plot via `uv run plot_memory_estimates_gqa.py`.
|
||||
|
||||
|
||||
## GQA Code Examples
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
# This file collects all the relevant code that we covered thus far
|
||||
# throughout Chapters 3-4.
|
||||
# throughout Chapters 3-4, adapted to use Grouped-Query Attention (GQA).
|
||||
# This file can be run as a standalone script.
|
||||
|
||||
import argparse
|
||||
@@ -83,7 +83,8 @@ class GroupedQueryAttention(nn.Module):
|
||||
# Shape: (b, num_heads, num_tokens, num_tokens)
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
# Use the mask to fill attention scores
|
||||
####################################################
|
||||
# causal mask
|
||||
num_tokens_Q = queries.shape[-2]
|
||||
num_tokens_K = keys.shape[-2]
|
||||
device = queries.device
|
||||
@@ -101,6 +102,7 @@ class GroupedQueryAttention(nn.Module):
|
||||
k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long)
|
||||
mask = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0)
|
||||
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
@@ -111,7 +113,7 @@ class GroupedQueryAttention(nn.Module):
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
|
||||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||||
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
|
||||
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
@@ -184,7 +186,7 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
x = self.att(x, use_cache=use_cache)
|
||||
####################################################
|
||||
|
||||
@@ -211,7 +213,7 @@ class GPTModel(nn.Module):
|
||||
# self.trf_blocks = nn.Sequential(
|
||||
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
self.trf_blocks = nn.ModuleList(
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
|
||||
@@ -228,8 +230,7 @@ class GPTModel(nn.Module):
|
||||
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
|
||||
# KV cache-related
|
||||
if use_cache:
|
||||
pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
||||
self.current_pos += seq_len
|
||||
@@ -243,7 +244,7 @@ class GPTModel(nn.Module):
|
||||
|
||||
# x = self.trf_blocks(x)
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
for blk in self.trf_blocks:
|
||||
x = blk(x, use_cache=use_cache)
|
||||
####################################################
|
||||
@@ -253,7 +254,7 @@ class GPTModel(nn.Module):
|
||||
return logits
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
@@ -261,34 +262,6 @@ class GPTModel(nn.Module):
|
||||
####################################################
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
# idx is (B, T) array of indices in the current context
|
||||
for _ in range(max_new_tokens):
|
||||
|
||||
# Crop current context if it exceeds the supported context size
|
||||
# E.g., if LLM supports only 5 tokens, and the context size is 10
|
||||
# then only the last 5 tokens are used as context
|
||||
idx_cond = idx[:, -context_size:]
|
||||
|
||||
# Get the predictions
|
||||
with torch.no_grad():
|
||||
logits = model(idx_cond)
|
||||
|
||||
# Focus only on the last time step
|
||||
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Get the idx of the vocab entry with the highest logits value
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
|
||||
|
||||
# Append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def generate_text_simple_cached(model, idx, max_new_tokens,
|
||||
context_size=None, use_cache=True):
|
||||
model.eval()
|
||||
@@ -314,7 +287,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
return idx
|
||||
####################################################
|
||||
|
||||
|
||||
def main():
|
||||
@@ -324,6 +296,7 @@ def main():
|
||||
parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
|
||||
parser.add_argument("--n_kv_groups", type=int, default=2, help="Number of key/value groups.")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
start_context = "Hello, I am"
|
||||
|
||||
@@ -33,7 +33,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related code
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
self.ptr_current_pos = 0
|
||||
@@ -53,7 +53,7 @@ class MultiHeadAttention(nn.Module):
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
if use_cache:
|
||||
if self.cache_k is None:
|
||||
self.cache_k, self.cache_v = keys_new, values_new
|
||||
@@ -74,7 +74,7 @@ class MultiHeadAttention(nn.Module):
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
# causal mask
|
||||
num_tokens_Q = queries.shape[-2]
|
||||
num_tokens_K = keys.shape[-2]
|
||||
device = queries.device
|
||||
@@ -107,12 +107,9 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
return context_vec
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def reset_cache(self):
|
||||
self.cache_k, self.cache_v = None, None
|
||||
self.ptr_current_pos = 0
|
||||
####################################################
|
||||
|
||||
|
||||
#####################################
|
||||
@@ -177,7 +174,7 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
x = self.att(x, use_cache=use_cache)
|
||||
####################################################
|
||||
|
||||
@@ -204,7 +201,7 @@ class GPTModel(nn.Module):
|
||||
# self.trf_blocks = nn.Sequential(
|
||||
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
self.trf_blocks = nn.ModuleList(
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
|
||||
@@ -221,8 +218,7 @@ class GPTModel(nn.Module):
|
||||
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
|
||||
# KV cache-related
|
||||
if use_cache:
|
||||
pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
||||
self.current_pos += seq_len
|
||||
@@ -236,7 +232,7 @@ class GPTModel(nn.Module):
|
||||
|
||||
# x = self.trf_blocks(x)
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
for blk in self.trf_blocks:
|
||||
x = blk(x, use_cache=use_cache)
|
||||
####################################################
|
||||
@@ -246,7 +242,7 @@ class GPTModel(nn.Module):
|
||||
return logits
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
# KV cache-related
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
@@ -254,34 +250,6 @@ class GPTModel(nn.Module):
|
||||
####################################################
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
# idx is (B, T) array of indices in the current context
|
||||
for _ in range(max_new_tokens):
|
||||
|
||||
# Crop current context if it exceeds the supported context size
|
||||
# E.g., if LLM supports only 5 tokens, and the context size is 10
|
||||
# then only the last 5 tokens are used as context
|
||||
idx_cond = idx[:, -context_size:]
|
||||
|
||||
# Get the predictions
|
||||
with torch.no_grad():
|
||||
logits = model(idx_cond)
|
||||
|
||||
# Focus only on the last time step
|
||||
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Get the idx of the vocab entry with the highest logits value
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
|
||||
|
||||
# Append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def generate_text_simple_cached(model, idx, max_new_tokens,
|
||||
context_size=None, use_cache=True):
|
||||
model.eval()
|
||||
@@ -307,7 +275,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
return idx
|
||||
####################################################
|
||||
|
||||
|
||||
def main():
|
||||
@@ -316,6 +283,7 @@ def main():
|
||||
parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
|
||||
parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
start_context = "Hello, I am"
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
# Plot KV-cache
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Import from ./memory_estimator.py
|
||||
from memory_estimator import kv_bytes_total, DTYPE_BYTES
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:.2f}"
|
||||
|
||||
|
||||
def savings_percent(total_mha, total_gqa):
|
||||
return (1.0 - (total_gqa / total_mha)) * 100.0
|
||||
|
||||
|
||||
def plot_savings_vs_nkvgroups():
|
||||
n_heads = 24
|
||||
emb_dim = 2048
|
||||
n_layers = 48
|
||||
batch_size = 1
|
||||
context_length = 8192
|
||||
dtype = "bf16"
|
||||
bytes_per_elem = DTYPE_BYTES[dtype]
|
||||
|
||||
total_mha = kv_bytes_total(
|
||||
batch_size,
|
||||
context_length,
|
||||
emb_dim,
|
||||
n_heads,
|
||||
n_heads,
|
||||
n_layers,
|
||||
bytes_per_elem,
|
||||
)
|
||||
|
||||
xs = []
|
||||
ys = []
|
||||
for n_kv_groups in range(1, n_heads + 1):
|
||||
n_kv_heads = n_heads // n_kv_groups
|
||||
total_gqa = kv_bytes_total(
|
||||
batch_size,
|
||||
context_length,
|
||||
emb_dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
n_layers,
|
||||
bytes_per_elem,
|
||||
)
|
||||
xs.append(n_kv_groups)
|
||||
ys.append(savings_percent(total_mha, total_gqa))
|
||||
|
||||
plt.figure()
|
||||
plt.plot(xs, ys, marker="o")
|
||||
plt.xlabel("n_kv_groups")
|
||||
plt.ylabel("Savings vs MHA (%)")
|
||||
plt.title(
|
||||
"KV-cache Savings vs n_kv_groups\n"
|
||||
"(n_heads=24, emb_dim=2048, n_layers=48, "
|
||||
"batch=1, context=8192, dtype=bf16)", fontsize=8
|
||||
)
|
||||
plt.grid(True)
|
||||
plt.tight_layout()
|
||||
plt.savefig("savings_vs_n_kv_groups.pdf")
|
||||
|
||||
|
||||
def plot_abs_kv_vs_context():
|
||||
n_heads = 24
|
||||
emb_dim = 2048
|
||||
n_layers = 48
|
||||
batch_size = 1
|
||||
n_kv_groups = 4
|
||||
dtype = "bf16"
|
||||
bytes_per_elem = DTYPE_BYTES[dtype]
|
||||
|
||||
n_kv_heads_mha = n_heads
|
||||
n_kv_heads_gqa = n_heads // n_kv_groups
|
||||
|
||||
context_lengths = [
|
||||
256, 512, 1024, 2048, 4096, 8192,
|
||||
16384, 32768, 65536, 131072
|
||||
]
|
||||
|
||||
xs = []
|
||||
mha_gb = []
|
||||
gqa_gb = []
|
||||
savings_pct = None
|
||||
|
||||
for L in context_lengths:
|
||||
total_mha = kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_kv_heads_mha, n_layers, bytes_per_elem
|
||||
)
|
||||
total_gqa = kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_kv_heads_gqa, n_layers, bytes_per_elem
|
||||
)
|
||||
xs.append(L)
|
||||
mha_gb.append(float(bytes_convert(total_mha)))
|
||||
gqa_gb.append(float(bytes_convert(total_gqa)))
|
||||
if savings_pct is None:
|
||||
savings_pct = savings_percent(total_mha, total_gqa)
|
||||
|
||||
plt.figure()
|
||||
plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)")
|
||||
plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
|
||||
plt.xscale("log")
|
||||
plt.xlabel("context_length (log scale)")
|
||||
plt.ylabel("Total KV cache (GB)")
|
||||
plt.title(
|
||||
"KV-cache vs Context Length\n"
|
||||
"(n_heads=24, emb_dim=2048, n_layers=48, "
|
||||
"batch=1, n_kv_groups=4, dtype=bf16)", fontsize=8
|
||||
)
|
||||
plt.grid(True, which="both")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig("kv_bytes_vs_context_length.pdf")
|
||||
print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
|
||||
print(f"Example (context_length={context_lengths[-1]}): "
|
||||
f"MHA={bytes_convert(total_mha)} GB, "
|
||||
f"GQA={bytes_convert(total_gqa)} GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
plot_savings_vs_nkvgroups()
|
||||
plot_abs_kv_vs_context()
|
||||
81
ch04/04_gqa/plot_memory_estimates_gqa.py
Normal file
81
ch04/04_gqa/plot_memory_estimates_gqa.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
# Plot KV-cache vs context length for different n_kv_groups
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Import from ./memory_estimator.py
|
||||
from memory_estimator_gqa import kv_bytes_total, DTYPE_BYTES
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:.2f}"
|
||||
|
||||
|
||||
def savings_percent(total_mha, total_gqa):
|
||||
return (1.0 - (total_gqa / total_mha)) * 100.0
|
||||
|
||||
|
||||
def plot_abs_kv_vs_context_multi_groups():
|
||||
n_heads = 24
|
||||
emb_dim = 2048
|
||||
n_layers = 48
|
||||
batch_size = 1
|
||||
dtype = "bf16"
|
||||
bytes_per_elem = DTYPE_BYTES[dtype]
|
||||
|
||||
# x-axis (log scale)
|
||||
context_lengths = [
|
||||
256, 512, 1024, 2048, 4096, 8192,
|
||||
16384, 32768, 65536, 131072
|
||||
]
|
||||
|
||||
mha_gb = []
|
||||
for L in context_lengths:
|
||||
total_mha = kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_heads, # MHA: n_kv_heads = n_heads
|
||||
n_layers, bytes_per_elem
|
||||
)
|
||||
mha_gb.append(float(bytes_convert(total_mha)))
|
||||
|
||||
plt.figure()
|
||||
plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
|
||||
|
||||
# GQA curves for selected n_kv_groups
|
||||
groups_list = [4, 8, 12, 24]
|
||||
for g in groups_list:
|
||||
n_kv_heads = n_heads // g
|
||||
gqa_gb = []
|
||||
for L in context_lengths:
|
||||
total_gqa = kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem
|
||||
)
|
||||
gqa_gb.append(float(bytes_convert(total_gqa)))
|
||||
|
||||
# Compression rate relative to MHA
|
||||
comp = (n_heads / n_kv_heads) if n_kv_heads > 0 else float("inf")
|
||||
plt.plot(context_lengths, gqa_gb, marker="o",
|
||||
label=f"GQA (n_kv_groups={g}, {comp:,.1f}× compression)")
|
||||
|
||||
plt.xscale("log")
|
||||
plt.xlabel("context_length (log scale)")
|
||||
plt.ylabel("Total KV cache (GB)")
|
||||
plt.title(
|
||||
"KV-cache vs Context Length — MHA vs GQA (multi-group)\n"
|
||||
"(n_heads=24, emb_dim=2048, n_layers=48, batch=1, dtype=bf16)",
|
||||
fontsize=8
|
||||
)
|
||||
plt.grid(True, which="both")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig("kv_bytes_vs_context_length.pdf")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
plot_abs_kv_vs_context_multi_groups()
|
||||
Reference in New Issue
Block a user