Use GB instead of GiB consistently (#875)

This commit is contained in:
Sebastian Raschka
2025-10-11 09:11:33 -05:00
committed by GitHub
parent c814814d72
commit bf27ad1485
2 changed files with 19 additions and 8 deletions

View File

@@ -49,8 +49,11 @@ def main():
"n_kv_groups": args.n_kv_groups,
}
if cfg["n_heads"] % cfg["n_kv_groups"] != 0:
raise ValueError("n_kv_groups must divide n_heads exactly.")
bytes_per_elem = DTYPE_BYTES[args.dtype]
head_dim = cfg["emb_dim"] / cfg["n_heads"]
head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"])
n_kv_heads_mha = cfg["n_heads"]
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
@@ -83,7 +86,7 @@ def main():
print(f"{k:17}: {v}")
print(f"batch_size : {args.batch_size}")
print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)")
print(f"head_dim : {int(head_dim)}")
print(f"head_dim : {head_dim}")
print(f"GQA n_kv_heads : {n_kv_heads_gqa}")
print()

View File

@@ -12,6 +12,11 @@ import matplotlib.pyplot as plt
from memory_estimator import kv_bytes_total, DTYPE_BYTES
def bytes_convert(n):
    """Format a byte count as decimal gigabytes.

    The divisor is 1000**3 (GB), not 1024**3 (GiB).

    Args:
        n: Number of bytes (int or float).

    Returns:
        A string holding the gigabyte value rounded to two decimal
        places, without a unit suffix (e.g. ``"1.50"``).
    """
    return format(n / 1_000_000_000, ".2f")
def savings_percent(total_mha, total_gqa):
    """Return how much smaller ``total_gqa`` is than ``total_mha``, in percent.

    Args:
        total_mha: Baseline total; used as the denominator.
        total_gqa: Total to compare against the baseline.

    Returns:
        ``(1 - total_gqa / total_mha) * 100`` as a float. The result is
        negative when ``total_gqa`` exceeds ``total_mha``.

    Raises:
        ZeroDivisionError: If ``total_mha`` is zero.
    """
    fraction_remaining = total_gqa / total_mha
    return (1.0 - fraction_remaining) * 100.0
@@ -83,8 +88,8 @@ def plot_abs_kv_vs_context():
]
xs = []
mha_gib = []
gqa_gib = []
mha_gb = []
gqa_gb = []
savings_pct = None
for L in context_lengths:
@@ -97,14 +102,14 @@ def plot_abs_kv_vs_context():
n_kv_heads_gqa, n_layers, bytes_per_elem
)
xs.append(L)
mha_gib.append(total_mha / (1024**3))
gqa_gib.append(total_gqa / (1024**3))
mha_gb.append(float(bytes_convert(total_mha)))
gqa_gb.append(float(bytes_convert(total_gqa)))
if savings_pct is None:
savings_pct = savings_percent(total_mha, total_gqa)
plt.figure()
plt.plot(xs, mha_gib, marker="o", label="MHA (KV total)")
plt.plot(xs, gqa_gib, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)")
plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
plt.xscale("log")
plt.xlabel("context_length (log scale)")
plt.ylabel("Total KV cache (GB)")
@@ -118,6 +123,9 @@ def plot_abs_kv_vs_context():
plt.tight_layout()
plt.savefig("kv_bytes_vs_context_length.pdf")
print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
print(f"Example (context_length={context_lengths[-1]}): "
f"MHA={bytes_convert(total_mha)} GB, "
f"GQA={bytes_convert(total_gqa)} GB")
if __name__ == "__main__":