Optional weight tying for Qwen3 and Llama3.2 pretraining (#949)

* optional weight tying for Qwen3 and Llama3.2

* typo
This commit is contained in:
casinca
2026-01-14 16:07:04 +01:00
committed by GitHub
parent e0dbec3331
commit 9c4be478f8
7 changed files with 17 additions and 9 deletions

View File

@@ -65,7 +65,7 @@ class Llama3Model(nn.Module):
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
# Reusuable utilities
# Reusable utilities
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],