mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Optional weight tying for Qwen3 and Llama3.2 pretraining (#949)
* optional weight tying for Qwen3 and Llama3.2 * typo
This commit is contained in:
@@ -358,7 +358,7 @@
|
|||||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Reusable utilities\n",
|
||||||
" cos, sin = compute_rope_params(\n",
|
" cos, sin = compute_rope_params(\n",
|
||||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||||
" theta_base=cfg[\"rope_base\"],\n",
|
" theta_base=cfg[\"rope_base\"],\n",
|
||||||
|
|||||||
@@ -432,7 +432,7 @@
|
|||||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Reusable utilities\n",
|
||||||
" if cfg[\"head_dim\"] is None:\n",
|
" if cfg[\"head_dim\"] is None:\n",
|
||||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
|
|||||||
@@ -422,7 +422,7 @@
|
|||||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Reusable utilities\n",
|
||||||
" if cfg[\"head_dim\"] is None:\n",
|
" if cfg[\"head_dim\"] is None:\n",
|
||||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
|
|||||||
@@ -388,7 +388,7 @@
|
|||||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Reusable utilities\n",
|
||||||
" if cfg[\"head_dim\"] is None:\n",
|
" if cfg[\"head_dim\"] is None:\n",
|
||||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
|
|||||||
@@ -113,7 +113,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -329,7 +329,11 @@
|
|||||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Uncomment the following code to tie weights\n",
|
||||||
|
" # self.out_head.weight = self.tok_emb.weight\n",
|
||||||
|
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
|
||||||
|
"\n",
|
||||||
|
" # Reusable utilities\n",
|
||||||
" cos, sin = compute_rope_params(\n",
|
" cos, sin = compute_rope_params(\n",
|
||||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||||
" theta_base=cfg[\"rope_base\"],\n",
|
" theta_base=cfg[\"rope_base\"],\n",
|
||||||
|
|||||||
@@ -121,7 +121,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
"id": "86000d74-624a-48f0-86da-f41926cb9e04",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -332,7 +332,11 @@
|
|||||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
||||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Reusuable utilities\n",
|
" # Uncomment the following code to tie weights\n",
|
||||||
|
" # self.out_head.weight = self.tok_emb.weight\n",
|
||||||
|
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
|
||||||
|
"\n",
|
||||||
|
" # Reusable utilities\n",
|
||||||
" if cfg[\"head_dim\"] is None:\n",
|
" if cfg[\"head_dim\"] is None:\n",
|
||||||
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class Llama3Model(nn.Module):
|
|||||||
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||||
|
|
||||||
# Reusuable utilities
|
# Reusable utilities
|
||||||
cos, sin = compute_rope_params(
|
cos, sin = compute_rope_params(
|
||||||
head_dim=cfg["emb_dim"] // cfg["n_heads"],
|
head_dim=cfg["emb_dim"] // cfg["n_heads"],
|
||||||
theta_base=cfg["rope_base"],
|
theta_base=cfg["rope_base"],
|
||||||
|
|||||||
Reference in New Issue
Block a user