Readability and code quality improvements (#959)

* Consistent dataset naming

* consistent section headers
This commit is contained in:
Sebastian Raschka
2026-02-17 19:44:56 -05:00
committed by GitHub
parent 7b1f740f74
commit be5e2a3331
48 changed files with 419 additions and 297 deletions

View File

@@ -17,21 +17,21 @@ DTYPE_BYTES = {
}
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
# Full attention (MHA)
d_head = emb_dim // n_heads
per_layer = batch * context_length * n_heads * d_head * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
# Simple Gated DeltaNet (no convolutional mixing)
d_head = emb_dim // n_heads
per_layer = batch * n_heads * d_head * d_head * bytes_per_elem
return per_layer * n_layers
def gb(x):
def convert_to_gb(x):
return x / 1e9
@@ -52,13 +52,13 @@ def main():
# 1) Full attention only
mha_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# 2) DeltaNet only
dnet_bytes_const = kv_bytes_total_deltanet_no_conv(
dnet_bytes_const = calc_kv_bytes_total_deltanet_no_conv(
args.batch, args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads
)
@@ -68,17 +68,17 @@ def main():
n_mha_layers = args.n_layers / 4
n_dnet_layers = args.n_layers - n_mha_layers
mix_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ calc_kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# Convert to GB
mha_gb = gb(mha_bytes)
dnet_gb = gb(dnet_bytes)
mix_gb = gb(mix_bytes)
mha_gb = convert_to_gb(mha_bytes)
dnet_gb = convert_to_gb(dnet_bytes)
mix_gb = convert_to_gb(mix_bytes)
# Plot
fig, ax = plt.subplots(figsize=(7, 4.5))