Readability and code quality improvements (#959)

* Consistent dataset naming

* consistent section headers
This commit is contained in:
Sebastian Raschka
2026-02-17 19:44:56 -05:00
committed by GitHub
parent 7b1f740f74
commit be5e2a3331
48 changed files with 419 additions and 297 deletions

View File

@@ -15,18 +15,18 @@ DTYPE_BYTES = {
}
def bytes_to_gb(n_bytes):
def convert_bytes_to_gb(n_bytes):
return n_bytes / (1000. ** 3)
def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
head_dim = emb_dim / n_heads
per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
def calc_kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
return batch * context_length * n_layers * latent_dim * bytes_per_elem
@@ -45,27 +45,27 @@ def plot_abs_kv_vs_context_multiple():
mha_gb = []
for L in context_lengths:
total_mha = kv_bytes_total_mha(
total_mha = calc_kv_bytes_total_mha(
batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem
)
mha_gb.append(bytes_to_gb(total_mha))
mha_gb.append(convert_bytes_to_gb(total_mha))
latent_dims = [1024, 512, 256, 64]
plt.figure()
plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
L_ref = context_lengths[-1]
total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
total_mha_ref = calc_kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
for latent_dim in latent_dims:
mla_gb = []
for L in context_lengths:
total_mla = kv_bytes_total_mla(
total_mla = calc_kv_bytes_total_mla(
batch_size, L, n_layers, latent_dim, bytes_per_elem
)
mla_gb.append(bytes_to_gb(total_mla))
mla_gb.append(convert_bytes_to_gb(total_mla))
total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
total_mla_ref = calc_kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf")
plt.plot(context_lengths, mla_gb, marker="o",