mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Qwen3 From Scratch (#678)
* Qwen3 From Scratch * rev other file * upd * upd * upd * url fixes
This commit is contained in:
committed by
GitHub
parent
e700c66b7a
commit
3d4bce6d57
@@ -19,6 +19,36 @@ import tiktoken
|
||||
import torch
|
||||
|
||||
|
||||
class LitGPTRMSNorm(torch.nn.Module):
|
||||
"""Root Mean Square Layer Normalization.
|
||||
|
||||
From https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
|
||||
Apache License 2.0-Clause License: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
|
||||
|
||||
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
||||
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(size))
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.add_unit_offset = add_unit_offset
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
# NOTE: the original RMSNorm paper implementation is not equivalent
|
||||
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
||||
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
||||
weight = (1 + self.weight) if self.add_unit_offset else self.weight
|
||||
return (x_normed * weight.float()).to(dtype=dtype)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@@ -179,3 +209,25 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
hidden_size = 64
|
||||
batch_size = 8
|
||||
seq_len = 16
|
||||
|
||||
rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
|
||||
lit_norm = LitGPTRMSNorm(hidden_size)
|
||||
|
||||
# Sync weights
|
||||
with torch.no_grad():
|
||||
lit_norm.weight.copy_(lit_norm.weight)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
out1 = rms_norm(x)
|
||||
out2 = lit_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
194
pkg/llms_from_scratch/tests/test_qwen3.py
Normal file
194
pkg/llms_from_scratch/tests/test_qwen3.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.qwen3 import (
|
||||
compute_rope_params,
|
||||
apply_rope,
|
||||
QWEN_CONFIG_06_B,
|
||||
RMSNorm,
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer
|
||||
)
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
|
||||
# Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
|
||||
# License: Apache License, Version 2.0 (see file above)
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen3RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
print(input_dtype)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_rope():
|
||||
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
# Settings
|
||||
batch_size = 1
|
||||
context_len = 8192
|
||||
num_heads = 4
|
||||
head_dim = 16
|
||||
rope_theta = 1_000_000
|
||||
|
||||
# Instantiate RoPE parameters
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=rope_theta,
|
||||
context_length=context_len,
|
||||
)
|
||||
|
||||
# Dummy query and key tensors
|
||||
torch.manual_seed(123)
|
||||
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
queries_rot = apply_rope(queries, cos, sin)
|
||||
keys_rot = apply_rope(keys, cos, sin)
|
||||
|
||||
# Generate reference RoPE via HF
|
||||
class RoPEConfig:
|
||||
rope_type = "qwen3"
|
||||
factor = 1.0
|
||||
dim: int = head_dim
|
||||
rope_theta = 1_000_000
|
||||
max_position_embeddings: int = 8192
|
||||
hidden_size = head_dim * num_heads
|
||||
num_attention_heads = num_heads
|
||||
|
||||
config = RoPEConfig()
|
||||
|
||||
rot_emb = Qwen3RotaryEmbedding(config=config)
|
||||
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
||||
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
||||
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
|
||||
|
||||
torch.testing.assert_close(sin, ref_sin.squeeze(0))
|
||||
torch.testing.assert_close(cos, ref_cos.squeeze(0))
|
||||
torch.testing.assert_close(keys_rot, ref_keys_rot)
|
||||
torch.testing.assert_close(queries_rot, ref_queries_rot)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen3_weights_path(tmp_path_factory):
|
||||
"""Creates and saves a deterministic Llama3 model for testing."""
|
||||
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
|
||||
|
||||
if not path.exists():
|
||||
torch.manual_seed(123)
|
||||
model = Qwen3Model(QWEN_CONFIG_06_B)
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
|
||||
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(QWEN_CONFIG_06_B)
|
||||
model.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model.eval()
|
||||
|
||||
start_context = "Llamas eat"
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
encoded = tokenizer.encode(start_context)
|
||||
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
|
||||
|
||||
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
|
||||
print("\nInput text:", start_context)
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"]
|
||||
)
|
||||
print("Encoded output text:", out)
|
||||
expect = torch.tensor([
|
||||
[43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
hidden_size = 64
|
||||
batch_size = 8
|
||||
seq_len = 16
|
||||
|
||||
rms_norm = RMSNorm(hidden_size)
|
||||
ref_norm = Qwen3RMSNorm(hidden_size)
|
||||
|
||||
# Sync weights
|
||||
with torch.no_grad():
|
||||
ref_norm.weight.copy_(ref_norm.weight)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
out1 = rms_norm(x)
|
||||
out2 = ref_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_tokenizer_equivalence():
|
||||
from transformers import AutoTokenizer
|
||||
repo_id = "Qwen/Qwen3-0.6B"
|
||||
tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
for states in ((True, True), (False, False)):
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
|
||||
repo_id=repo_id,
|
||||
add_generation_prompt=states[0],
|
||||
add_thinking=states[1]
|
||||
)
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
input_token_ids_ref = tokenizer_ref.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=states[0],
|
||||
enable_thinking=states[1],
|
||||
)
|
||||
assert input_token_ids == input_token_ids_ref, states
|
||||
|
||||
output_text = tokenizer.decode(input_token_ids)
|
||||
out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
|
||||
assert output_text == out_text_ref, states
|
||||
Reference in New Issue
Block a user