From bf27ad1485fa79b1261c980a8a0ec44612690c20 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sat, 11 Oct 2025 09:11:33 -0500 Subject: [PATCH] Use GB instead of GiB consistently (#875) --- ch04/04_gqa/memory_estimator.py | 7 +++++-- ch04/04_gqa/plot_memory_estimates.py | 20 ++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ch04/04_gqa/memory_estimator.py b/ch04/04_gqa/memory_estimator.py index 1151acb..276cc31 100644 --- a/ch04/04_gqa/memory_estimator.py +++ b/ch04/04_gqa/memory_estimator.py @@ -49,8 +49,11 @@ def main(): "n_kv_groups": args.n_kv_groups, } + if cfg["n_heads"] % cfg["n_kv_groups"] != 0: + raise ValueError("n_kv_groups must divide n_heads exactly.") + bytes_per_elem = DTYPE_BYTES[args.dtype] - head_dim = cfg["emb_dim"] / cfg["n_heads"] + head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"]) n_kv_heads_mha = cfg["n_heads"] n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"] @@ -83,7 +86,7 @@ def main(): print(f"{k:17}: {v}") print(f"batch_size : {args.batch_size}") print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)") - print(f"head_dim : {int(head_dim)}") + print(f"head_dim : {head_dim}") print(f"GQA n_kv_heads : {n_kv_heads_gqa}") print() diff --git a/ch04/04_gqa/plot_memory_estimates.py b/ch04/04_gqa/plot_memory_estimates.py index 5476c1e..6e874d3 100644 --- a/ch04/04_gqa/plot_memory_estimates.py +++ b/ch04/04_gqa/plot_memory_estimates.py @@ -12,6 +12,11 @@ import matplotlib.pyplot as plt from memory_estimator import kv_bytes_total, DTYPE_BYTES +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:.2f}" + + def savings_percent(total_mha, total_gqa): return (1.0 - (total_gqa / total_mha)) * 100.0 @@ -83,8 +88,8 @@ def plot_abs_kv_vs_context(): ] xs = [] - mha_gib = [] - gqa_gib = [] + mha_gb = [] + gqa_gb = [] savings_pct = None for L in context_lengths: @@ -97,14 +102,14 @@ def plot_abs_kv_vs_context(): n_kv_heads_gqa, n_layers, bytes_per_elem ) xs.append(L) - mha_gib.append(total_mha / (1024**3)) - gqa_gib.append(total_gqa / (1024**3)) + mha_gb.append(float(bytes_convert(total_mha))) + gqa_gb.append(float(bytes_convert(total_gqa))) if savings_pct is None: savings_pct = savings_percent(total_mha, total_gqa) plt.figure() - plt.plot(xs, mha_gib, marker="o", label="MHA (KV total)") - plt.plot(xs, gqa_gib, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})") + plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)") + plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})") plt.xscale("log") plt.xlabel("context_length (log scale)") plt.ylabel("Total KV cache (GB)") @@ -118,6 +123,9 @@ def plot_abs_kv_vs_context(): plt.tight_layout() plt.savefig("kv_bytes_vs_context_length.pdf") print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.") + print(f"Example (context_length={context_lengths[-1]}): " + f"MHA={bytes_convert(total_mha)} GB, " + f"GQA={bytes_convert(total_gqa)} GB") if __name__ == "__main__":