mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Readability and code quality improvements (#959)
* Consistent dataset naming * consistent section headers
This commit is contained in:
committed by
GitHub
parent
7b1f740f74
commit
be5e2a3331
@@ -73,6 +73,7 @@
|
||||
"id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.1 Coding an LLM architecture"
|
||||
]
|
||||
},
|
||||
@@ -323,6 +324,7 @@
|
||||
"id": "f8332a00-98da-4eb4-b882-922776a89917",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.2 Normalizing activations with layer normalization"
|
||||
]
|
||||
},
|
||||
@@ -606,6 +608,7 @@
|
||||
"id": "11190e7d-8c29-4115-824a-e03702f9dd54",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.3 Implementing a feed forward network with GELU activations"
|
||||
]
|
||||
},
|
||||
@@ -789,6 +792,7 @@
|
||||
"id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.4 Adding shortcut connections"
|
||||
]
|
||||
},
|
||||
@@ -950,6 +954,7 @@
|
||||
"id": "cae578ca-e564-42cf-8635-a2267047cdff",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.5 Connecting attention and linear layers in a transformer block"
|
||||
]
|
||||
},
|
||||
@@ -1068,6 +1073,7 @@
|
||||
"id": "46618527-15ac-4c32-ad85-6cfea83e006e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.6 Coding the GPT model"
|
||||
]
|
||||
},
|
||||
@@ -1332,6 +1338,7 @@
|
||||
"id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 4.7 Generating text"
|
||||
]
|
||||
},
|
||||
@@ -1519,11 +1526,20 @@
|
||||
"id": "a35278b6-9e5c-480f-83e5-011a1173648f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## Summary and takeaways\n",
|
||||
"\n",
|
||||
"- See the [./gpt.py](./gpt.py) script, a self-contained script containing the GPT model we implement in this Jupyter notebook\n",
|
||||
"- You can find the exercise solutions in [./exercise-solutions.ipynb](./exercise-solutions.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4821ac83-ef84-42c4-a327-32bf2820a8e5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -53,7 +53,8 @@
|
||||
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 4.1: Parameters in the feed forward versus attention module"
|
||||
" \n",
|
||||
"## Exercise 4.1: Parameters in the feed forward versus attention module"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -182,7 +183,8 @@
|
||||
"id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 4.2: Initialize larger GPT models"
|
||||
" \n",
|
||||
"## Exercise 4.2: Initialize larger GPT models"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -329,7 +331,8 @@
|
||||
"id": "f5f2306e-5dc8-498e-92ee-70ae7ec37ac1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 4.3: Using separate dropout parameters"
|
||||
" \n",
|
||||
"## Exercise 4.3: Using separate dropout parameters"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -451,7 +454,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -18,13 +18,13 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
def convert_bytes(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:,.2f} GB"
|
||||
|
||||
|
||||
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem):
|
||||
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem):
|
||||
head_dim = math.ceil(emb_dim / n_heads)
|
||||
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
@@ -58,7 +58,7 @@ def main():
|
||||
n_kv_heads_mha = cfg["n_heads"]
|
||||
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
|
||||
|
||||
total_mha = kv_bytes_total(
|
||||
total_mha = calc_kv_bytes_total(
|
||||
args.batch_size,
|
||||
cfg["context_length"],
|
||||
cfg["emb_dim"],
|
||||
@@ -68,7 +68,7 @@ def main():
|
||||
bytes_per_elem,
|
||||
)
|
||||
|
||||
total_gqa = kv_bytes_total(
|
||||
total_gqa = calc_kv_bytes_total(
|
||||
args.batch_size,
|
||||
cfg["context_length"],
|
||||
cfg["emb_dim"],
|
||||
@@ -91,8 +91,8 @@ def main():
|
||||
print()
|
||||
|
||||
print("==== KV-cache totals across all layers ====")
|
||||
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
|
||||
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
|
||||
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
|
||||
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
|
||||
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
|
||||
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Import from ./memory_estimator_gqa.py
|
||||
from memory_estimator_gqa import kv_bytes_total, DTYPE_BYTES
|
||||
from memory_estimator_gqa import calc_kv_bytes_total, DTYPE_BYTES
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
@@ -36,7 +36,7 @@ def plot_abs_kv_vs_context_multi_groups():
|
||||
|
||||
mha_gb = []
|
||||
for L in context_lengths:
|
||||
total_mha = kv_bytes_total(
|
||||
total_mha = calc_kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_heads, # MHA: n_kv_heads = n_heads
|
||||
n_layers, bytes_per_elem
|
||||
@@ -52,7 +52,7 @@ def plot_abs_kv_vs_context_multi_groups():
|
||||
n_kv_heads = n_heads // g
|
||||
gqa_gb = []
|
||||
for L in context_lengths:
|
||||
total_gqa = kv_bytes_total(
|
||||
total_gqa = calc_kv_bytes_total(
|
||||
batch_size, L, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem
|
||||
)
|
||||
|
||||
@@ -17,20 +17,20 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
def convert_bytes(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:,.2f} GB"
|
||||
|
||||
|
||||
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem):
|
||||
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
|
||||
n_kv_heads, n_layers, bytes_per_elem):
|
||||
# Generic KV-cache: per-head dim is ceil(emb_dim / n_heads); the factor 2 accounts for K and V
|
||||
head_dim = math.ceil(emb_dim / n_heads)
|
||||
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
|
||||
def calc_mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
|
||||
# Simple MLA (per-token compressed latent)
|
||||
# bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem
|
||||
return batch * context_length * n_layers * latent_dim * bytes_per_elem
|
||||
@@ -66,7 +66,7 @@ def main():
|
||||
n_kv_heads_mha = cfg["n_heads"]
|
||||
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
|
||||
|
||||
total_mha = kv_bytes_total(
|
||||
total_mha = calc_kv_bytes_total(
|
||||
args.batch_size,
|
||||
cfg["context_length"],
|
||||
cfg["emb_dim"],
|
||||
@@ -76,7 +76,7 @@ def main():
|
||||
bytes_per_elem,
|
||||
)
|
||||
|
||||
total_gqa = kv_bytes_total(
|
||||
total_gqa = calc_kv_bytes_total(
|
||||
args.batch_size,
|
||||
cfg["context_length"],
|
||||
cfg["emb_dim"],
|
||||
@@ -86,7 +86,7 @@ def main():
|
||||
bytes_per_elem,
|
||||
)
|
||||
|
||||
total_mla = mla_bytes_total(
|
||||
total_mla = calc_mla_bytes_total(
|
||||
args.batch_size,
|
||||
cfg["context_length"],
|
||||
cfg["n_layers"],
|
||||
@@ -110,9 +110,9 @@ def main():
|
||||
print()
|
||||
|
||||
print("==== KV-cache totals across all layers ====")
|
||||
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
|
||||
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
|
||||
print(f"MLA total KV cache : {bytes_convert(total_mla)}")
|
||||
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
|
||||
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
|
||||
print(f"MLA total KV cache : {convert_bytes(total_mla)}")
|
||||
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
|
||||
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")
|
||||
print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x")
|
||||
|
||||
@@ -15,18 +15,18 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_to_gb(n_bytes):
|
||||
def convert_bytes_to_gb(n_bytes):
|
||||
return n_bytes / (1000. ** 3)
|
||||
|
||||
|
||||
def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
|
||||
n_layers, bytes_per_elem):
|
||||
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
|
||||
n_layers, bytes_per_elem):
|
||||
head_dim = emb_dim / n_heads
|
||||
per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
|
||||
def calc_kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
|
||||
return batch * context_length * n_layers * latent_dim * bytes_per_elem
|
||||
|
||||
|
||||
@@ -45,27 +45,27 @@ def plot_abs_kv_vs_context_multiple():
|
||||
|
||||
mha_gb = []
|
||||
for L in context_lengths:
|
||||
total_mha = kv_bytes_total_mha(
|
||||
total_mha = calc_kv_bytes_total_mha(
|
||||
batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem
|
||||
)
|
||||
mha_gb.append(bytes_to_gb(total_mha))
|
||||
mha_gb.append(convert_bytes_to_gb(total_mha))
|
||||
|
||||
latent_dims = [1024, 512, 256, 64]
|
||||
plt.figure()
|
||||
plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
|
||||
|
||||
L_ref = context_lengths[-1]
|
||||
total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
|
||||
total_mha_ref = calc_kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
|
||||
|
||||
for latent_dim in latent_dims:
|
||||
mla_gb = []
|
||||
for L in context_lengths:
|
||||
total_mla = kv_bytes_total_mla(
|
||||
total_mla = calc_kv_bytes_total_mla(
|
||||
batch_size, L, n_layers, latent_dim, bytes_per_elem
|
||||
)
|
||||
mla_gb.append(bytes_to_gb(total_mla))
|
||||
mla_gb.append(convert_bytes_to_gb(total_mla))
|
||||
|
||||
total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
|
||||
total_mla_ref = calc_kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
|
||||
comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf")
|
||||
|
||||
plt.plot(context_lengths, mla_gb, marker="o",
|
||||
|
||||
@@ -17,12 +17,12 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
def convert_bytes(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:,.2f} GB"
|
||||
|
||||
|
||||
def kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
|
||||
def calc_kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
|
||||
# KV = batch * tokens * head_dim * n_kv_heads * 2 (K,V) * bytes
|
||||
return batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
|
||||
|
||||
@@ -64,10 +64,10 @@ def estimate_totals(context_length, sliding_window_size, emb_dim, n_heads, n_lay
|
||||
L = context_length
|
||||
|
||||
# Per-layer costs
|
||||
per_mha_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
|
||||
per_gqa_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
|
||||
per_mha_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
|
||||
per_gqa_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
|
||||
per_mha_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
|
||||
per_gqa_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
|
||||
per_mha_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
|
||||
per_gqa_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
|
||||
|
||||
# Totals
|
||||
total_mha_allfull = per_mha_full * n_layers
|
||||
@@ -140,10 +140,10 @@ def main():
|
||||
print()
|
||||
|
||||
print("==== KV-cache totals across all layers ====")
|
||||
print(f"MHA KV total : {bytes_convert(res['total_mha_allfull'])}")
|
||||
print(f"GQA KV total : {bytes_convert(res['total_gqa_allfull'])}")
|
||||
print(f"MHA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_mha'])}")
|
||||
print(f"GQA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_gqa'])}")
|
||||
print(f"MHA KV total : {convert_bytes(res['total_mha_allfull'])}")
|
||||
print(f"GQA KV total : {convert_bytes(res['total_gqa_allfull'])}")
|
||||
print(f"MHA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_mha'])}")
|
||||
print(f"GQA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_gqa'])}")
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_to_gb(n_bytes):
|
||||
def convert_bytes_to_gb(n_bytes):
|
||||
return n_bytes / (1000.0 ** 3)
|
||||
|
||||
|
||||
@@ -39,22 +39,22 @@ def parse_ratio(ratio_str):
|
||||
raise ValueError("--swa_ratio must be in the form 'a:b' with nonnegative integers and a+b>0")
|
||||
|
||||
|
||||
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
|
||||
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
|
||||
# For MHA, n_kv_heads = n_heads, so head_dim * n_kv_heads = emb_dim and n_heads cancels out:
|
||||
# total = B * L * E * 2 (K,V) * bytes * n_layers
|
||||
return batch * context_length * emb_dim * 2 * bytes_per_elem * n_layers
|
||||
|
||||
|
||||
def kv_bytes_total_gqa(
|
||||
def calc_kv_bytes_total_gqa(
|
||||
batch, context_length, emb_dim, n_layers, bytes_per_elem, n_kv_groups
|
||||
):
|
||||
# For GQA, n_kv_heads = n_heads / n_kv_groups
|
||||
# => scale the MHA total by 1 / n_kv_groups
|
||||
base = kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
|
||||
base = calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
|
||||
return base / n_kv_groups
|
||||
|
||||
|
||||
def kv_bytes_total_mha_swa(
|
||||
def calc_kv_bytes_total_mha_swa(
|
||||
batch, context_length, emb_dim, n_layers, bytes_per_elem, window, swa_ratio
|
||||
):
|
||||
# Split layers into SWA vs Full
|
||||
@@ -63,16 +63,16 @@ def kv_bytes_total_mha_swa(
|
||||
n_swa_layers = int(round(n_layers * (a / total_blocks)))
|
||||
n_full_layers = n_layers - n_swa_layers
|
||||
|
||||
total_full = kv_bytes_total_mha(
|
||||
total_full = calc_kv_bytes_total_mha(
|
||||
batch, context_length, emb_dim, n_full_layers, bytes_per_elem
|
||||
)
|
||||
total_swa = kv_bytes_total_mha(
|
||||
total_swa = calc_kv_bytes_total_mha(
|
||||
batch, window, emb_dim, n_swa_layers, bytes_per_elem
|
||||
)
|
||||
return total_full + total_swa
|
||||
|
||||
|
||||
def kv_bytes_total_gqa_swa(
|
||||
def calc_kv_bytes_total_gqa_swa(
|
||||
batch,
|
||||
context_length,
|
||||
emb_dim,
|
||||
@@ -87,7 +87,7 @@ def kv_bytes_total_gqa_swa(
|
||||
n_swa_layers = int(round(n_layers * (a / total_blocks)))
|
||||
n_full_layers = n_layers - n_swa_layers
|
||||
|
||||
total_full = kv_bytes_total_gqa(
|
||||
total_full = calc_kv_bytes_total_gqa(
|
||||
batch,
|
||||
context_length,
|
||||
emb_dim,
|
||||
@@ -95,7 +95,7 @@ def kv_bytes_total_gqa_swa(
|
||||
bytes_per_elem,
|
||||
n_kv_groups,
|
||||
)
|
||||
total_swa = kv_bytes_total_gqa(
|
||||
total_swa = calc_kv_bytes_total_gqa(
|
||||
batch, window, emb_dim, n_swa_layers, bytes_per_elem, n_kv_groups
|
||||
)
|
||||
return total_full + total_swa
|
||||
@@ -144,10 +144,10 @@ def main():
|
||||
] = []
|
||||
|
||||
for L in context_lengths:
|
||||
total_mha = kv_bytes_total_mha(
|
||||
total_mha = calc_kv_bytes_total_mha(
|
||||
batch_size, L, emb_dim, n_layers, bytes_per_elem
|
||||
)
|
||||
total_mha_swa = kv_bytes_total_mha_swa(
|
||||
total_mha_swa = calc_kv_bytes_total_mha_swa(
|
||||
batch_size,
|
||||
L,
|
||||
emb_dim,
|
||||
@@ -156,16 +156,16 @@ def main():
|
||||
window=args.sliding_window_size,
|
||||
swa_ratio=args.swa_ratio,
|
||||
)
|
||||
series["MHA (KV total)"].append(bytes_to_gb(total_mha))
|
||||
series["MHA (KV total)"].append(convert_bytes_to_gb(total_mha))
|
||||
series[
|
||||
f"SWA on MHA (ratio {args.swa_ratio}, W={args.sliding_window_size})"
|
||||
].append(bytes_to_gb(total_mha_swa))
|
||||
].append(convert_bytes_to_gb(total_mha_swa))
|
||||
|
||||
if valid_g4:
|
||||
total_gqa = kv_bytes_total_gqa(
|
||||
total_gqa = calc_kv_bytes_total_gqa(
|
||||
batch_size, L, emb_dim, n_layers, bytes_per_elem, n_kv_groups=kv_groups
|
||||
)
|
||||
total_gqa_swa = kv_bytes_total_gqa_swa(
|
||||
total_gqa_swa = calc_kv_bytes_total_gqa_swa(
|
||||
batch_size,
|
||||
L,
|
||||
emb_dim,
|
||||
@@ -175,10 +175,10 @@ def main():
|
||||
window=args.sliding_window_size,
|
||||
swa_ratio=args.swa_ratio,
|
||||
)
|
||||
series["GQA kv_groups=4 (full)"].append(bytes_to_gb(total_gqa))
|
||||
series["GQA kv_groups=4 (full)"].append(convert_bytes_to_gb(total_gqa))
|
||||
series[
|
||||
f"SWA on GQA kv_groups=4 (ratio {args.swa_ratio}, W={args.sliding_window_size})"
|
||||
].append(bytes_to_gb(total_gqa_swa))
|
||||
].append(convert_bytes_to_gb(total_gqa_swa))
|
||||
|
||||
plt.figure(figsize=(10, 5))
|
||||
x = np.array(context_lengths, dtype=float)
|
||||
|
||||
@@ -14,7 +14,7 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def bytes_convert(n):
|
||||
def convert_bytes(n):
|
||||
gb = n / (1000 ** 3)
|
||||
return f"{gb:,.2f} GB"
|
||||
|
||||
@@ -28,19 +28,19 @@ def get_num_param_matrices(ffn_type):
|
||||
raise ValueError("--ffn_type must be 'gelu' or 'swiglu'")
|
||||
|
||||
|
||||
def ffn_params(emb_dim, hidden_dim, ffn_type):
|
||||
def calc_ffn_params(emb_dim, hidden_dim, ffn_type):
|
||||
return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim
|
||||
|
||||
|
||||
def router_params(emb_dim, num_experts):
|
||||
def calc_router_params(emb_dim, num_experts):
|
||||
return emb_dim * num_experts
|
||||
|
||||
|
||||
def estimate_params_and_hidden(
|
||||
emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False
|
||||
):
|
||||
P_dense = ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
R = router_params(emb_dim, num_experts)
|
||||
P_dense = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
R = calc_router_params(emb_dim, num_experts)
|
||||
|
||||
if match_dense:
|
||||
num_param_matrices = get_num_param_matrices(ffn_type)
|
||||
@@ -52,7 +52,7 @@ def estimate_params_and_hidden(
|
||||
else:
|
||||
moe_hidden_dim = hidden_dim
|
||||
|
||||
per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type)
|
||||
per_expert_params = calc_ffn_params(emb_dim, moe_hidden_dim, ffn_type)
|
||||
moe_total = num_experts * per_expert_params + R
|
||||
|
||||
return {
|
||||
@@ -110,15 +110,15 @@ def main():
|
||||
|
||||
print("==== Model weights (parameters) ====")
|
||||
print(f"{'Dense FFN params':23}: {res['dense_params']:,} "
|
||||
f"({bytes_convert(res['dense_params'] * bytes_per_elem)})")
|
||||
f"({convert_bytes(res['dense_params'] * bytes_per_elem)})")
|
||||
print(f"{'Per-expert params':23}: {res['per_expert_params']:,} "
|
||||
f"({bytes_convert(res['per_expert_params'] * bytes_per_elem)})")
|
||||
f"({convert_bytes(res['per_expert_params'] * bytes_per_elem)})")
|
||||
print(f"{'Router params':23}: {res['router']:,} "
|
||||
f"({bytes_convert(res['router'] * bytes_per_elem)})")
|
||||
f"({convert_bytes(res['router'] * bytes_per_elem)})")
|
||||
print(f"{'MoE TOTAL params':23}: {res['moe_total']:,} "
|
||||
f"({bytes_convert(res['moe_total'] * bytes_per_elem)})")
|
||||
f"({convert_bytes(res['moe_total'] * bytes_per_elem)})")
|
||||
print(f"{'MoE ACTIVE/Token':23}: {moe_active_params_per_token:,} "
|
||||
f"({bytes_convert(moe_active_params_per_token * bytes_per_elem)})")
|
||||
f"({convert_bytes(moe_active_params_per_token * bytes_per_elem)})")
|
||||
print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}")
|
||||
print()
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@
|
||||
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
from ffn_moe_memory_estimator import (
|
||||
from memory_estimator_moe import (
|
||||
estimate_params_and_hidden,
|
||||
ffn_params,
|
||||
router_params,
|
||||
calc_ffn_params,
|
||||
calc_router_params,
|
||||
)
|
||||
|
||||
|
||||
def moe_active_and_total(
|
||||
def calc_moe_active_and_total(
|
||||
emb_dim,
|
||||
hidden_dim,
|
||||
ffn_type,
|
||||
@@ -22,8 +22,8 @@ def moe_active_and_total(
|
||||
match_dense=True,
|
||||
):
|
||||
if match_dense:
|
||||
dense_params = ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
router = router_params(emb_dim, num_experts)
|
||||
dense_params = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
router = calc_router_params(emb_dim, num_experts)
|
||||
if dense_params <= router:
|
||||
match_dense = False
|
||||
|
||||
@@ -52,11 +52,11 @@ def plot_active_params_vs_experts(
|
||||
experts = [1, 2, 4, 8, 16, 32, 64, 128, 192, 256, 384, 512]
|
||||
experts = [e for e in experts if e <= max_experts]
|
||||
|
||||
dense_active = ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
dense_active = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
|
||||
moe_active = []
|
||||
moe_total = []
|
||||
for e in experts:
|
||||
active, total = moe_active_and_total(
|
||||
active, total = calc_moe_active_and_total(
|
||||
emb_dim=emb_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
ffn_type=ffn_type,
|
||||
|
||||
@@ -17,21 +17,21 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
# Full attention (MHA)
|
||||
d_head = emb_dim // n_heads
|
||||
per_layer = batch * context_length * n_heads * d_head * 2 * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
def calc_kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
# Simple Gated DeltaNet (no convolutional mixing)
|
||||
d_head = emb_dim // n_heads
|
||||
per_layer = batch * n_heads * d_head * d_head * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def gb(x):
|
||||
def convert_to_gb(x):
|
||||
return x / 1e9
|
||||
|
||||
|
||||
@@ -52,13 +52,13 @@ def main():
|
||||
|
||||
# 1) Full attention only
|
||||
mha_bytes = np.array([
|
||||
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
for t in ctx
|
||||
], dtype=float)
|
||||
|
||||
# 2) DeltaNet only
|
||||
dnet_bytes_const = kv_bytes_total_deltanet_no_conv(
|
||||
dnet_bytes_const = calc_kv_bytes_total_deltanet_no_conv(
|
||||
args.batch, args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads
|
||||
)
|
||||
@@ -68,17 +68,17 @@ def main():
|
||||
n_mha_layers = args.n_layers / 4
|
||||
n_dnet_layers = args.n_layers - n_mha_layers
|
||||
mix_bytes = np.array([
|
||||
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
+ kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
+ calc_kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
for t in ctx
|
||||
], dtype=float)
|
||||
|
||||
# Convert to GB
|
||||
mha_gb = gb(mha_bytes)
|
||||
dnet_gb = gb(dnet_bytes)
|
||||
mix_gb = gb(mix_bytes)
|
||||
mha_gb = convert_to_gb(mha_bytes)
|
||||
dnet_gb = convert_to_gb(dnet_bytes)
|
||||
mix_gb = convert_to_gb(mix_bytes)
|
||||
|
||||
# Plot
|
||||
fig, ax = plt.subplots(figsize=(7, 4.5))
|
||||
|
||||
Reference in New Issue
Block a user