mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Readability and code quality improvements (#959)
* Consistent dataset naming * consistent section headers
This commit is contained in:
committed by
GitHub
parent
7b1f740f74
commit
be5e2a3331
@@ -17,21 +17,21 @@ DTYPE_BYTES = {
|
||||
}
|
||||
|
||||
|
||||
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
# Full attention (MHA)
|
||||
d_head = emb_dim // n_heads
|
||||
per_layer = batch * context_length * n_heads * d_head * 2 * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
def calc_kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
|
||||
# Simple Gated DeltaNet (no convolutional mixing)
|
||||
d_head = emb_dim // n_heads
|
||||
per_layer = batch * n_heads * d_head * d_head * bytes_per_elem
|
||||
return per_layer * n_layers
|
||||
|
||||
|
||||
def gb(x):
|
||||
def convert_to_gb(x):
|
||||
return x / 1e9
|
||||
|
||||
|
||||
@@ -52,13 +52,13 @@ def main():
|
||||
|
||||
# 1) Full attention only
|
||||
mha_bytes = np.array([
|
||||
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
for t in ctx
|
||||
], dtype=float)
|
||||
|
||||
# 2) DeltaNet only
|
||||
dnet_bytes_const = kv_bytes_total_deltanet_no_conv(
|
||||
dnet_bytes_const = calc_kv_bytes_total_deltanet_no_conv(
|
||||
args.batch, args.emb_dim, args.n_layers,
|
||||
bytes_per_elem, args.n_heads
|
||||
)
|
||||
@@ -68,17 +68,17 @@ def main():
|
||||
n_mha_layers = args.n_layers / 4
|
||||
n_dnet_layers = args.n_layers - n_mha_layers
|
||||
mix_bytes = np.array([
|
||||
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
+ kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
+ calc_kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
|
||||
bytes_per_elem, args.n_heads)
|
||||
for t in ctx
|
||||
], dtype=float)
|
||||
|
||||
# Convert to GB
|
||||
mha_gb = gb(mha_bytes)
|
||||
dnet_gb = gb(dnet_bytes)
|
||||
mix_gb = gb(mix_bytes)
|
||||
mha_gb = convert_to_gb(mha_bytes)
|
||||
dnet_gb = convert_to_gb(dnet_bytes)
|
||||
mix_gb = convert_to_gb(mix_bytes)
|
||||
|
||||
# Plot
|
||||
fig, ax = plt.subplots(figsize=(7, 4.5))
|
||||
|
||||
Reference in New Issue
Block a user