Readability and code quality improvements (#959)

* Consistent dataset naming

* Consistent section headers
This commit is contained in:
Sebastian Raschka
2026-02-17 19:44:56 -05:00
committed by GitHub
parent 7b1f740f74
commit be5e2a3331
48 changed files with 419 additions and 297 deletions

View File

@@ -73,6 +73,7 @@
"id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
"metadata": {},
"source": [
" \n",
"## 4.1 Coding an LLM architecture"
]
},
@@ -323,6 +324,7 @@
"id": "f8332a00-98da-4eb4-b882-922776a89917",
"metadata": {},
"source": [
" \n",
"## 4.2 Normalizing activations with layer normalization"
]
},
@@ -606,6 +608,7 @@
"id": "11190e7d-8c29-4115-824a-e03702f9dd54",
"metadata": {},
"source": [
" \n",
"## 4.3 Implementing a feed forward network with GELU activations"
]
},
@@ -789,6 +792,7 @@
"id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
"metadata": {},
"source": [
" \n",
"## 4.4 Adding shortcut connections"
]
},
@@ -950,6 +954,7 @@
"id": "cae578ca-e564-42cf-8635-a2267047cdff",
"metadata": {},
"source": [
" \n",
"## 4.5 Connecting attention and linear layers in a transformer block"
]
},
@@ -1068,6 +1073,7 @@
"id": "46618527-15ac-4c32-ad85-6cfea83e006e",
"metadata": {},
"source": [
" \n",
"## 4.6 Coding the GPT model"
]
},
@@ -1332,6 +1338,7 @@
"id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
"metadata": {},
"source": [
" \n",
"## 4.7 Generating text"
]
},
@@ -1519,11 +1526,20 @@
"id": "a35278b6-9e5c-480f-83e5-011a1173648f",
"metadata": {},
"source": [
" \n",
"## Summary and takeaways\n",
"\n",
"- See the [./gpt.py](./gpt.py) script, a self-contained script containing the GPT model we implement in this Jupyter notebook\n",
"- You can find the exercise solutions in [./exercise-solutions.ipynb](./exercise-solutions.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4821ac83-ef84-42c4-a327-32bf2820a8e5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -53,7 +53,8 @@
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"# Exercise 4.1: Parameters in the feed forward versus attention module"
" \n",
"## Exercise 4.1: Parameters in the feed forward versus attention module"
]
},
{
@@ -182,7 +183,8 @@
"id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
"metadata": {},
"source": [
"# Exercise 4.2: Initialize larger GPT models"
" \n",
"## Exercise 4.2: Initialize larger GPT models"
]
},
{
@@ -329,7 +331,8 @@
"id": "f5f2306e-5dc8-498e-92ee-70ae7ec37ac1",
"metadata": {},
"source": [
"# Exercise 4.3: Using separate dropout parameters"
" \n",
"## Exercise 4.3: Using separate dropout parameters"
]
},
{
@@ -451,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -18,13 +18,13 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
head_dim = math.ceil(emb_dim / n_heads)
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
return per_layer * n_layers
@@ -58,7 +58,7 @@ def main():
n_kv_heads_mha = cfg["n_heads"]
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -68,7 +68,7 @@ def main():
bytes_per_elem,
)
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -91,8 +91,8 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")

View File

@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
# Import from ./memory_estimator.py
from memory_estimator_gqa import kv_bytes_total, DTYPE_BYTES
from memory_estimator_gqa import calc_kv_bytes_total, DTYPE_BYTES
def bytes_convert(n):
@@ -36,7 +36,7 @@ def plot_abs_kv_vs_context_multi_groups():
mha_gb = []
for L in context_lengths:
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
batch_size, L, emb_dim, n_heads,
n_heads, # MHA: n_kv_heads = n_heads
n_layers, bytes_per_elem
@@ -52,7 +52,7 @@ def plot_abs_kv_vs_context_multi_groups():
n_kv_heads = n_heads // g
gqa_gb = []
for L in context_lengths:
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
batch_size, L, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem
)

View File

@@ -17,20 +17,20 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
# Generic KV-cache: per-head dim is embed_dim / n_heads, times 2 for K and V
head_dim = math.ceil(emb_dim / n_heads)
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
return per_layer * n_layers
def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
def calc_mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
# Simple MLA (per-token compressed latent)
# bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem
return batch * context_length * n_layers * latent_dim * bytes_per_elem
@@ -66,7 +66,7 @@ def main():
n_kv_heads_mha = cfg["n_heads"]
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -76,7 +76,7 @@ def main():
bytes_per_elem,
)
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -86,7 +86,7 @@ def main():
bytes_per_elem,
)
total_mla = mla_bytes_total(
total_mla = calc_mla_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["n_layers"],
@@ -110,9 +110,9 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
print(f"MLA total KV cache : {bytes_convert(total_mla)}")
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
print(f"MLA total KV cache : {convert_bytes(total_mla)}")
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")
print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x")

View File

@@ -15,18 +15,18 @@ DTYPE_BYTES = {
}
def bytes_to_gb(n_bytes):
def convert_bytes_to_gb(n_bytes):
return n_bytes / (1000. ** 3)
def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
head_dim = emb_dim / n_heads
per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
def calc_kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
return batch * context_length * n_layers * latent_dim * bytes_per_elem
@@ -45,27 +45,27 @@ def plot_abs_kv_vs_context_multiple():
mha_gb = []
for L in context_lengths:
total_mha = kv_bytes_total_mha(
total_mha = calc_kv_bytes_total_mha(
batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem
)
mha_gb.append(bytes_to_gb(total_mha))
mha_gb.append(convert_bytes_to_gb(total_mha))
latent_dims = [1024, 512, 256, 64]
plt.figure()
plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
L_ref = context_lengths[-1]
total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
total_mha_ref = calc_kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
for latent_dim in latent_dims:
mla_gb = []
for L in context_lengths:
total_mla = kv_bytes_total_mla(
total_mla = calc_kv_bytes_total_mla(
batch_size, L, n_layers, latent_dim, bytes_per_elem
)
mla_gb.append(bytes_to_gb(total_mla))
mla_gb.append(convert_bytes_to_gb(total_mla))
total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
total_mla_ref = calc_kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf")
plt.plot(context_lengths, mla_gb, marker="o",

View File

@@ -17,12 +17,12 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
def calc_kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
# KV = batch * tokens * head_dim * n_kv_heads * 2 (K,V) * bytes
return batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
@@ -64,10 +64,10 @@ def estimate_totals(context_length, sliding_window_size, emb_dim, n_heads, n_lay
L = context_length
# Per-layer costs
per_mha_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
# Totals
total_mha_allfull = per_mha_full * n_layers
@@ -140,10 +140,10 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA KV total : {bytes_convert(res['total_mha_allfull'])}")
print(f"GQA KV total : {bytes_convert(res['total_gqa_allfull'])}")
print(f"MHA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_mha'])}")
print(f"GQA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_gqa'])}")
print(f"MHA KV total : {convert_bytes(res['total_mha_allfull'])}")
print(f"GQA KV total : {convert_bytes(res['total_gqa_allfull'])}")
print(f"MHA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_mha'])}")
print(f"GQA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_gqa'])}")
print()

View File

@@ -24,7 +24,7 @@ DTYPE_BYTES = {
}
def bytes_to_gb(n_bytes):
def convert_bytes_to_gb(n_bytes):
return n_bytes / (1000.0 ** 3)
@@ -39,22 +39,22 @@ def parse_ratio(ratio_str):
raise ValueError("--swa_ratio must be in the form 'a:b' with nonnegative integers and a+b>0")
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
# For MHA, n_kv_heads = n_heads, which cancels out:
# total = B * L * E * 2 (K,V) * bytes * n_layers
return batch * context_length * emb_dim * 2 * bytes_per_elem * n_layers
def kv_bytes_total_gqa(
def calc_kv_bytes_total_gqa(
batch, context_length, emb_dim, n_layers, bytes_per_elem, n_kv_groups
):
# For GQA, n_kv_heads = n_heads / n_kv_groups
# => scale the MHA total by 1 / n_kv_groups
base = kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
base = calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
return base / n_kv_groups
def kv_bytes_total_mha_swa(
def calc_kv_bytes_total_mha_swa(
batch, context_length, emb_dim, n_layers, bytes_per_elem, window, swa_ratio
):
# Split layers into SWA vs Full
@@ -63,16 +63,16 @@ def kv_bytes_total_mha_swa(
n_swa_layers = int(round(n_layers * (a / total_blocks)))
n_full_layers = n_layers - n_swa_layers
total_full = kv_bytes_total_mha(
total_full = calc_kv_bytes_total_mha(
batch, context_length, emb_dim, n_full_layers, bytes_per_elem
)
total_swa = kv_bytes_total_mha(
total_swa = calc_kv_bytes_total_mha(
batch, window, emb_dim, n_swa_layers, bytes_per_elem
)
return total_full + total_swa
def kv_bytes_total_gqa_swa(
def calc_kv_bytes_total_gqa_swa(
batch,
context_length,
emb_dim,
@@ -87,7 +87,7 @@ def kv_bytes_total_gqa_swa(
n_swa_layers = int(round(n_layers * (a / total_blocks)))
n_full_layers = n_layers - n_swa_layers
total_full = kv_bytes_total_gqa(
total_full = calc_kv_bytes_total_gqa(
batch,
context_length,
emb_dim,
@@ -95,7 +95,7 @@ def kv_bytes_total_gqa_swa(
bytes_per_elem,
n_kv_groups,
)
total_swa = kv_bytes_total_gqa(
total_swa = calc_kv_bytes_total_gqa(
batch, window, emb_dim, n_swa_layers, bytes_per_elem, n_kv_groups
)
return total_full + total_swa
@@ -144,10 +144,10 @@ def main():
] = []
for L in context_lengths:
total_mha = kv_bytes_total_mha(
total_mha = calc_kv_bytes_total_mha(
batch_size, L, emb_dim, n_layers, bytes_per_elem
)
total_mha_swa = kv_bytes_total_mha_swa(
total_mha_swa = calc_kv_bytes_total_mha_swa(
batch_size,
L,
emb_dim,
@@ -156,16 +156,16 @@ def main():
window=args.sliding_window_size,
swa_ratio=args.swa_ratio,
)
series["MHA (KV total)"].append(bytes_to_gb(total_mha))
series["MHA (KV total)"].append(convert_bytes_to_gb(total_mha))
series[
f"SWA on MHA (ratio {args.swa_ratio}, W={args.sliding_window_size})"
].append(bytes_to_gb(total_mha_swa))
].append(convert_bytes_to_gb(total_mha_swa))
if valid_g4:
total_gqa = kv_bytes_total_gqa(
total_gqa = calc_kv_bytes_total_gqa(
batch_size, L, emb_dim, n_layers, bytes_per_elem, n_kv_groups=kv_groups
)
total_gqa_swa = kv_bytes_total_gqa_swa(
total_gqa_swa = calc_kv_bytes_total_gqa_swa(
batch_size,
L,
emb_dim,
@@ -175,10 +175,10 @@ def main():
window=args.sliding_window_size,
swa_ratio=args.swa_ratio,
)
series["GQA kv_groups=4 (full)"].append(bytes_to_gb(total_gqa))
series["GQA kv_groups=4 (full)"].append(convert_bytes_to_gb(total_gqa))
series[
f"SWA on GQA kv_groups=4 (ratio {args.swa_ratio}, W={args.sliding_window_size})"
].append(bytes_to_gb(total_gqa_swa))
].append(convert_bytes_to_gb(total_gqa_swa))
plt.figure(figsize=(10, 5))
x = np.array(context_lengths, dtype=float)

View File

@@ -14,7 +14,7 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
@@ -28,19 +28,19 @@ def get_num_param_matrices(ffn_type):
raise ValueError("--ffn_type must be 'gelu' or 'swiglu'")
def ffn_params(emb_dim, hidden_dim, ffn_type):
def calc_ffn_params(emb_dim, hidden_dim, ffn_type):
return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim
def router_params(emb_dim, num_experts):
def calc_router_params(emb_dim, num_experts):
return emb_dim * num_experts
def estimate_params_and_hidden(
emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False
):
P_dense = ffn_params(emb_dim, hidden_dim, ffn_type)
R = router_params(emb_dim, num_experts)
P_dense = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
R = calc_router_params(emb_dim, num_experts)
if match_dense:
num_param_matrices = get_num_param_matrices(ffn_type)
@@ -52,7 +52,7 @@ def estimate_params_and_hidden(
else:
moe_hidden_dim = hidden_dim
per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type)
per_expert_params = calc_ffn_params(emb_dim, moe_hidden_dim, ffn_type)
moe_total = num_experts * per_expert_params + R
return {
@@ -110,15 +110,15 @@ def main():
print("==== Model weights (parameters) ====")
print(f"{'Dense FFN params':23}: {res['dense_params']:,} "
f"({bytes_convert(res['dense_params'] * bytes_per_elem)})")
f"({convert_bytes(res['dense_params'] * bytes_per_elem)})")
print(f"{'Per-expert params':23}: {res['per_expert_params']:,} "
f"({bytes_convert(res['per_expert_params'] * bytes_per_elem)})")
f"({convert_bytes(res['per_expert_params'] * bytes_per_elem)})")
print(f"{'Router params':23}: {res['router']:,} "
f"({bytes_convert(res['router'] * bytes_per_elem)})")
f"({convert_bytes(res['router'] * bytes_per_elem)})")
print(f"{'MoE TOTAL params':23}: {res['moe_total']:,} "
f"({bytes_convert(res['moe_total'] * bytes_per_elem)})")
f"({convert_bytes(res['moe_total'] * bytes_per_elem)})")
print(f"{'MoE ACTIVE/Token':23}: {moe_active_params_per_token:,} "
f"({bytes_convert(moe_active_params_per_token * bytes_per_elem)})")
f"({convert_bytes(moe_active_params_per_token * bytes_per_elem)})")
print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}")
print()

View File

@@ -6,14 +6,14 @@
import argparse
import matplotlib.pyplot as plt
from ffn_moe_memory_estimator import (
from memory_estimator_moe import (
estimate_params_and_hidden,
ffn_params,
router_params,
calc_ffn_params,
calc_router_params,
)
def moe_active_and_total(
def calc_moe_active_and_total(
emb_dim,
hidden_dim,
ffn_type,
@@ -22,8 +22,8 @@ def moe_active_and_total(
match_dense=True,
):
if match_dense:
dense_params = ffn_params(emb_dim, hidden_dim, ffn_type)
router = router_params(emb_dim, num_experts)
dense_params = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
router = calc_router_params(emb_dim, num_experts)
if dense_params <= router:
match_dense = False
@@ -52,11 +52,11 @@ def plot_active_params_vs_experts(
experts = [1, 2, 4, 8, 16, 32, 64, 128, 192, 256, 384, 512]
experts = [e for e in experts if e <= max_experts]
dense_active = ffn_params(emb_dim, hidden_dim, ffn_type)
dense_active = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
moe_active = []
moe_total = []
for e in experts:
active, total = moe_active_and_total(
active, total = calc_moe_active_and_total(
emb_dim=emb_dim,
hidden_dim=hidden_dim,
ffn_type=ffn_type,

View File

@@ -17,21 +17,21 @@ DTYPE_BYTES = {
}
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
# Full attention (MHA)
d_head = emb_dim // n_heads
per_layer = batch * context_length * n_heads * d_head * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
# Simple Gated DeltaNet (no convolutional mixing)
d_head = emb_dim // n_heads
per_layer = batch * n_heads * d_head * d_head * bytes_per_elem
return per_layer * n_layers
def gb(x):
def convert_to_gb(x):
return x / 1e9
@@ -52,13 +52,13 @@ def main():
# 1) Full attention only
mha_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# 2) DeltaNet only
dnet_bytes_const = kv_bytes_total_deltanet_no_conv(
dnet_bytes_const = calc_kv_bytes_total_deltanet_no_conv(
args.batch, args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads
)
@@ -68,17 +68,17 @@ def main():
n_mha_layers = args.n_layers / 4
n_dnet_layers = args.n_layers - n_mha_layers
mix_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ calc_kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# Convert to GB
mha_gb = gb(mha_bytes)
dnet_gb = gb(dnet_bytes)
mix_gb = gb(mix_bytes)
mha_gb = convert_to_gb(mha_bytes)
dnet_gb = convert_to_gb(dnet_bytes)
mix_gb = convert_to_gb(mix_bytes)
# Plot
fig, ax = plt.subplots(figsize=(7, 4.5))