mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Use GB instead of GiB consistently (#875)
This commit is contained in:
committed by
GitHub
parent
c814814d72
commit
bf27ad1485
@@ -49,8 +49,11 @@ def main():
|
||||
"n_kv_groups": args.n_kv_groups,
|
||||
}
|
||||
|
||||
if cfg["n_heads"] % cfg["n_kv_groups"] != 0:
|
||||
raise ValueError("n_kv_groups must divide n_heads exactly.")
|
||||
|
||||
bytes_per_elem = DTYPE_BYTES[args.dtype]
|
||||
head_dim = cfg["emb_dim"] / cfg["n_heads"]
|
||||
head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"])
|
||||
|
||||
n_kv_heads_mha = cfg["n_heads"]
|
||||
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
|
||||
@@ -83,7 +86,7 @@ def main():
|
||||
print(f"{k:17}: {v}")
|
||||
print(f"batch_size : {args.batch_size}")
|
||||
print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)")
|
||||
print(f"head_dim : {int(head_dim)}")
|
||||
print(f"head_dim : {head_dim}")
|
||||
print(f"GQA n_kv_heads : {n_kv_heads_gqa}")
|
||||
print()
|
||||
|
||||
|
||||
@@ -12,6 +12,11 @@ import matplotlib.pyplot as plt
|
||||
from memory_estimator import kv_bytes_total, DTYPE_BYTES
|
||||
|
||||
|
||||
def bytes_convert(n):
    """Convert a raw byte count to decimal gigabytes (1 GB = 1000**3 bytes).

    Parameters
    ----------
    n : int or float
        Number of bytes.

    Returns
    -------
    str
        The gigabyte value rendered with exactly two decimal places.
    """
    gigabytes = n / 1_000_000_000  # decimal GB, not binary GiB (1024**3)
    return f"{gigabytes:.2f}"
|
||||
|
||||
|
||||
def savings_percent(total_mha, total_gqa):
    """Percentage of KV-cache memory that GQA saves relative to MHA.

    Parameters
    ----------
    total_mha : int or float
        Total KV-cache bytes under multi-head attention (must be non-zero).
    total_gqa : int or float
        Total KV-cache bytes under grouped-query attention.

    Returns
    -------
    float
        Savings in percent, e.g. 75.0 when GQA uses a quarter of MHA's bytes.
    """
    ratio = total_gqa / total_mha
    return 100.0 * (1.0 - ratio)
|
||||
|
||||
@@ -83,8 +88,8 @@ def plot_abs_kv_vs_context():
|
||||
]
|
||||
|
||||
xs = []
|
||||
mha_gib = []
|
||||
gqa_gib = []
|
||||
mha_gb = []
|
||||
gqa_gb = []
|
||||
savings_pct = None
|
||||
|
||||
for L in context_lengths:
|
||||
@@ -97,14 +102,14 @@ def plot_abs_kv_vs_context():
|
||||
n_kv_heads_gqa, n_layers, bytes_per_elem
|
||||
)
|
||||
xs.append(L)
|
||||
mha_gib.append(total_mha / (1024**3))
|
||||
gqa_gib.append(total_gqa / (1024**3))
|
||||
mha_gb.append(float(bytes_convert(total_mha)))
|
||||
gqa_gb.append(float(bytes_convert(total_gqa)))
|
||||
if savings_pct is None:
|
||||
savings_pct = savings_percent(total_mha, total_gqa)
|
||||
|
||||
plt.figure()
|
||||
plt.plot(xs, mha_gib, marker="o", label="MHA (KV total)")
|
||||
plt.plot(xs, gqa_gib, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
|
||||
plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)")
|
||||
plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
|
||||
plt.xscale("log")
|
||||
plt.xlabel("context_length (log scale)")
|
||||
plt.ylabel("Total KV cache (GB)")
|
||||
@@ -118,6 +123,9 @@ def plot_abs_kv_vs_context():
|
||||
plt.tight_layout()
|
||||
plt.savefig("kv_bytes_vs_context_length.pdf")
|
||||
print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
|
||||
print(f"Example (context_length={context_lengths[-1]}): "
|
||||
f"MHA={bytes_convert(total_mha)} GB, "
|
||||
f"GQA={bytes_convert(total_gqa)} GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user