mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Qwen3 From Scratch (#678)
* Qwen3 From Scratch * rev other file * upd * upd * upd * url fixes
This commit is contained in:
committed by
GitHub
parent
e700c66b7a
commit
3d4bce6d57
@@ -19,6 +19,36 @@ import tiktoken
|
||||
import torch
|
||||
|
||||
|
||||
class LitGPTRMSNorm(torch.nn.Module):
|
||||
"""Root Mean Square Layer Normalization.
|
||||
|
||||
From https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
|
||||
Apache License 2.0-Clause License: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
|
||||
|
||||
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
||||
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(size))
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.add_unit_offset = add_unit_offset
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
# NOTE: the original RMSNorm paper implementation is not equivalent
|
||||
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
||||
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
||||
weight = (1 + self.weight) if self.add_unit_offset else self.weight
|
||||
return (x_normed * weight.float()).to(dtype=dtype)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@@ -179,3 +209,25 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
hidden_size = 64
|
||||
batch_size = 8
|
||||
seq_len = 16
|
||||
|
||||
rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
|
||||
lit_norm = LitGPTRMSNorm(hidden_size)
|
||||
|
||||
# Sync weights
|
||||
with torch.no_grad():
|
||||
lit_norm.weight.copy_(lit_norm.weight)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
out1 = rms_norm(x)
|
||||
out2 = lit_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user