Qwen3 From Scratch (#678)

* Qwen3 From Scratch

* rev other file

* upd

* upd

* upd

* url fixes
This commit is contained in:
Sebastian Raschka
2025-06-19 18:44:38 -05:00
committed by GitHub
parent e700c66b7a
commit 3d4bce6d57
10 changed files with 2640 additions and 6 deletions

View File

@@ -19,6 +19,36 @@ import tiktoken
import torch
class LitGPTRMSNorm(torch.nn.Module):
    """RMS layer normalization, kept as a reference implementation.

    Adapted from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
    (Apache License 2.0: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE),
    which in turn derives from
    https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
    (BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE).

    The input is divided by the root of its mean square along ``dim`` and then
    multiplied by a learnable ``weight`` (or ``1 + weight`` when
    ``add_unit_offset`` is set).
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
        super().__init__()
        # Learnable scale, initialised to ones.
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim
        self.add_unit_offset = add_unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32 and return it in its original dtype."""
        input_dtype = x.dtype
        x_fp32 = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent
        mean_square = (x_fp32 * x_fp32).mean(dim=self.dim, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        normalized = x_fp32 * inv_rms

        if self.add_unit_offset:
            scale = 1 + self.weight
        else:
            scale = self.weight

        scaled = normalized * scale.float()
        return scaled.to(dtype=input_dtype)

    def reset_parameters(self) -> None:
        """Restore the scale parameter to all ones."""
        torch.nn.init.ones_(self.weight)
# True when the import system can locate a top-level module named
# "transformers"; find_spec only looks the module up, it does not import it.
# NOTE(review): presumably used to gate tests that need the package — confirm.
transformers_installed = importlib.util.find_spec("transformers") is not None
@@ -179,3 +209,25 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
])
assert torch.equal(expect, out)
def test_rmsnorm_equivalence():
    """Check that LitGPTRMSNorm matches torch.nn.RMSNorm on random input."""
    torch.manual_seed(42)

    hidden_size = 64
    batch_size = 8
    seq_len = 16

    rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
    lit_norm = LitGPTRMSNorm(hidden_size)

    # Sync weights: copy the reference module's scale into the LitGPT variant.
    # (Copying lit_norm.weight onto itself would be a no-op and sync nothing.)
    with torch.no_grad():
        lit_norm.weight.copy_(rms_norm.weight)

    x = torch.randn(batch_size, seq_len, hidden_size)

    out1 = rms_norm(x)
    out2 = lit_norm(x)

    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)