# Files
# LLMs-from-scratch/ch04/04_gqa/plot_memory_estimates.py
# 2025-10-11 09:11:33 -05:00
#
# 134 lines
# 3.5 KiB
# Python

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# Plot KV-cache
import matplotlib.pyplot as plt
# Import from ./memory_estimator.py
from memory_estimator import kv_bytes_total, DTYPE_BYTES
def bytes_convert(n):
    """Format a byte count as decimal gigabytes with two decimal places.

    Args:
        n: Number of bytes.

    Returns:
        The value ``n / 10**9`` rendered as a string with exactly two
        digits after the decimal point (no unit suffix).
    """
    bytes_per_gb = 1000 ** 3  # decimal GB (10**9 bytes), not GiB
    return format(n / bytes_per_gb, ".2f")
def savings_percent(total_mha, total_gqa):
    """Return how much smaller ``total_gqa`` is than ``total_mha``, in percent.

    Args:
        total_mha: Baseline total (the MHA figure); must be non-zero.
        total_gqa: Total to compare against the baseline (the GQA figure).

    Returns:
        ``(1 - total_gqa / total_mha) * 100`` as a float. The result is 0.0
        when both totals are equal and negative when ``total_gqa`` is larger.

    Raises:
        ZeroDivisionError: If ``total_mha`` is zero.
    """
    gqa_fraction = total_gqa / total_mha
    return 100.0 * (1.0 - gqa_fraction)
def plot_savings_vs_nkvgroups():
    """Plot the KV-cache savings of GQA relative to MHA versus ``n_kv_groups``.

    A fixed example configuration (defined at the top of the function) is
    used. The MHA baseline is obtained from ``kv_bytes_total`` with
    ``n_kv_heads == n_heads``; each GQA point uses
    ``n_kv_heads = n_heads // n_kv_groups``.

    Only group sizes that divide ``n_heads`` evenly are plotted. For any
    other value the floor division would silently reuse the KV-head count of
    a different group size (e.g. 24 // 5 == 24 // 6 == 4), which would draw
    a point for a configuration that does not correspond to its x value.

    The figure is written to ``savings_vs_n_kv_groups.pdf`` in the current
    working directory. Returns ``None``.
    """
    n_heads = 24
    emb_dim = 2048
    n_layers = 48
    batch_size = 1
    context_length = 8192
    dtype = "bf16"
    bytes_per_elem = DTYPE_BYTES[dtype]

    def kv_total(n_kv_heads):
        # Everything except the number of KV heads is fixed for this plot.
        return kv_bytes_total(
            batch_size,
            context_length,
            emb_dim,
            n_heads,
            n_kv_heads,
            n_layers,
            bytes_per_elem,
        )

    # MHA baseline: one KV head per attention head.
    total_mha = kv_total(n_heads)

    # Keep only group sizes for which n_heads // n_kv_groups is exact.
    xs = [
        n_kv_groups
        for n_kv_groups in range(1, n_heads + 1)
        if n_heads % n_kv_groups == 0
    ]
    ys = [
        savings_percent(total_mha, kv_total(n_heads // n_kv_groups))
        for n_kv_groups in xs
    ]

    plt.figure()
    plt.plot(xs, ys, marker="o")
    plt.xlabel("n_kv_groups")
    plt.ylabel("Savings vs MHA (%)")
    # Build the subtitle from the variables above so it cannot drift out of
    # sync with the configuration that was actually plotted.
    plt.title(
        "KV-cache Savings vs n_kv_groups\n"
        f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, "
        f"batch={batch_size}, context={context_length}, dtype={dtype})",
        fontsize=8,
    )
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("savings_vs_n_kv_groups.pdf")
def plot_abs_kv_vs_context():
    """Plot the total KV-cache size of MHA and GQA against context length.

    A fixed example configuration (defined at the top of the function) is
    used. For every context length in ``context_lengths`` the totals
    returned by ``kv_bytes_total`` are computed once with
    ``n_kv_heads == n_heads`` (MHA) and once with
    ``n_kv_heads == n_heads // n_kv_groups`` (GQA), converted to decimal
    gigabytes, and drawn on a log-scaled x axis.

    Side effects:
        * Writes ``kv_bytes_vs_context_length.pdf`` to the current working
          directory.
        * Prints the savings percentage (computed from the first context
          length) and the MHA/GQA sizes for the last context length.

    Returns ``None``.
    """
    n_heads = 24
    emb_dim = 2048
    n_layers = 48
    batch_size = 1
    n_kv_groups = 4
    dtype = "bf16"
    bytes_per_elem = DTYPE_BYTES[dtype]

    n_kv_heads_mha = n_heads
    n_kv_heads_gqa = n_heads // n_kv_groups

    context_lengths = [
        256, 512, 1024, 2048, 4096, 8192,
        16384, 32768, 65536, 131072,
    ]

    def kv_total(context_length, n_kv_heads):
        # Only the context length and the KV-head count vary in this plot.
        return kv_bytes_total(
            batch_size, context_length, emb_dim, n_heads,
            n_kv_heads, n_layers, bytes_per_elem,
        )

    mha_bytes = [kv_total(length, n_kv_heads_mha) for length in context_lengths]
    gqa_bytes = [kv_total(length, n_kv_heads_gqa) for length in context_lengths]

    # Convert to GB numerically for plotting. Going through the two-decimal
    # string produced by bytes_convert() would quantise the values to
    # 0.01 GB and visibly distort the points at small context lengths.
    bytes_per_gb = 1000 ** 3
    mha_gb = [total / bytes_per_gb for total in mha_bytes]
    gqa_gb = [total / bytes_per_gb for total in gqa_bytes]

    # Taken from the first context length; the message printed below
    # reports this single value as representative of all lengths.
    savings_pct = savings_percent(mha_bytes[0], gqa_bytes[0])

    plt.figure()
    plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
    plt.plot(
        context_lengths, gqa_gb, marker="o",
        label=f"GQA (n_kv_groups={n_kv_groups})",
    )
    plt.xscale("log")
    plt.xlabel("context_length (log scale)")
    plt.ylabel("Total KV cache (GB)")
    # Build the subtitle from the variables above so it cannot drift out of
    # sync with the configuration that was actually plotted.
    plt.title(
        "KV-cache vs Context Length\n"
        f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, "
        f"batch={batch_size}, n_kv_groups={n_kv_groups}, dtype={dtype})",
        fontsize=8,
    )
    plt.grid(True, which="both")
    plt.legend()
    plt.tight_layout()
    plt.savefig("kv_bytes_vs_context_length.pdf")

    print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
    # Report the largest context length explicitly instead of relying on
    # loop variables that happen to survive past the end of a for-loop.
    print(f"Example (context_length={context_lengths[-1]}): "
          f"MHA={bytes_convert(mha_bytes[-1])} GB, "
          f"GQA={bytes_convert(gqa_bytes[-1])} GB")
if __name__ == "__main__":
    # Generate both PDF figures, in this order, when run as a script.
    for make_plot in (plot_savings_vs_nkvgroups, plot_abs_kv_vs_context):
        make_plot()