Reduce Llama 3 RoPE memory requirements (#658)

* Llama3 from scratch improvements

* Fix Llama 3 expensive RoPE memory issue

* updates

* update package

* benchmark

* remove unused rescale_theta
Sebastian Raschka
2025-06-12 11:08:02 -05:00
committed by GitHub
parent 55e2a0978a
commit a3c4c33347
9 changed files with 405 additions and 2577 deletions
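To put the memory fix in perspective: the old code pinned `context_length` to 8_192 and kept a persistent `(context_length, context_length)` boolean causal-mask buffer; at Llama 3.2's native 131_072-token context such a buffer alone would need ~16 GiB, which is why the context had to be truncated (and RoPE's theta rescaled to match). Building the mask per forward pass removes that constraint. A back-of-the-envelope sketch (illustrative arithmetic, not part of the diff):

```python
# Persistent causal mask at the native context vs. an on-the-fly mask
# for a single forward pass (torch.bool = 1 byte per element).
full_ctx = 131_072   # native Llama 3.2 context length
num_tokens = 8_192   # example sequence length for one forward pass

persistent_gib = full_ctx * full_ctx / 2**30      # resident for the model's lifetime
transient_mib = num_tokens * num_tokens / 2**20   # freed after the forward pass

print(f"persistent full-context mask: {persistent_gib:.0f} GiB")  # 16 GiB
print(f"on-the-fly mask:              {transient_mib:.0f} MiB")   # 64 MiB
```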


@@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
 from llms_from_scratch.appendix_d import find_highest_gradient, train_model
 ```

+### Llama 3 (Bonus material)
+
+```python
+from llms_from_scratch.llama3 import (
+    Llama3Model,
+    Llama3ModelFast,
+    Llama3Tokenizer,
+    ChatFormat,
+    clean_text
+)
+```
+
-(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
+For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).


@@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe
 LLAMA32_CONFIG_1B = {
     "vocab_size": 128_256,           # Vocabulary size
-    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
-    "orig_context_length": 131_072,  # Context length that was used to train the model
+    "context_length": 131_072,       # Context length that was used to train the model
     "emb_dim": 2048,                 # Embedding dimension
     "n_heads": 32,                   # Number of attention heads
     "n_layers": 16,                  # Number of layers
@@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = {
 LLAMA32_CONFIG_3B = {
     "vocab_size": 128_256,           # Vocabulary size
-    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
-    "orig_context_length": 131_072,  # Context length that was used to train the model
+    "context_length": 131_072,       # Context length that was used to train the model
     "emb_dim": 3072,                 # Embedding dimension
     "n_heads": 24,                   # Number of attention heads
     "n_layers": 28,                  # Number of layers
@@ -67,17 +65,6 @@ class Llama3Model(nn.Module):
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

-        # Reusuable utilities
-        self.register_buffer(
-            "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
-            persistent=False
-        )
-
-        if cfg["orig_context_length"] != cfg["context_length"]:
-            cfg["rope_base"] = rescale_theta(
-                cfg["rope_base"],
-                cfg["orig_context_length"],
-                cfg["context_length"]
-            )
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],
@@ -92,8 +79,11 @@ class Llama3Model(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds

+        num_tokens = x.shape[1]
+        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
+
         for block in self.trf_blocks:
-            x = block(x, self.mask, self.cos, self.sin)
+            x = block(x, mask, self.cos, self.sin)
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
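The on-the-fly mask reproduces exactly the upper-triangular pattern the removed buffer held, truncated to the current sequence length; a minimal standalone check (not from the test suite):

```python
import torch

context_length, num_tokens = 16, 5   # tiny stand-ins for 131_072 and a real batch length

# old approach: one persistent full-size buffer
full_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()

# new approach: build only what the current forward pass needs
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

assert torch.equal(mask, full_mask[:num_tokens, :num_tokens])
```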
@@ -281,88 +271,104 @@ def apply_rope(x, cos, sin):
     return x_rotated.to(dtype=x.dtype)


-def rescale_theta(theta_old, context_length_old, context_length_new):
-    scaling_factor = context_length_new / context_length_old
-    theta_new = theta_old * scaling_factor
-    return theta_new


 ##########################################
 # Tokenizer
 ##########################################

 class Llama3Tokenizer:
+    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
+
     def __init__(self, model_path):
-        assert os.path.isfile(model_path), f"Model file {model_path} not found"
-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        if not os.path.isfile(model_path):
+            raise FileNotFoundError(model_path)

-        self.special_tokens = {
+        mergeable = load_tiktoken_bpe(model_path)
+
+        # hard-coded from Meta's tokenizer.json
+        self.special = {
             "<|begin_of_text|>": 128000,
             "<|end_of_text|>": 128001,
             "<|start_header_id|>": 128006,
             "<|end_header_id|>": 128007,
             "<|eot_id|>": 128009,
         }
-        self.special_tokens.update({
-            f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()
-        })
+        self.special.update({f"<|reserved_{i}|>": 128002 + i
+                             for i in range(256)
+                             if 128002 + i not in self.special.values()})

         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
-            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
-            mergeable_ranks=mergeable_ranks,
-            special_tokens=self.special_tokens
+            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
+                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
+                    r"|\p{N}{1,3}"
+                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
+                    r"|\s*[\r\n]+"
+                    r"|\s+(?!\S)"
+                    r"|\s+",
+            mergeable_ranks=mergeable,
+            special_tokens=self.special,
         )

-    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
+    def encode(self, text, bos=False, eos=False, allowed_special=set()):
+        ids: list[int] = []
+
         if bos:
-            tokens = [self.special_tokens["<|begin_of_text|>"]]
-        else:
-            tokens = []
-        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
+            ids.append(self.special["<|begin_of_text|>"])
+
+        # delegate to underlying tiktoken.Encoding.encode
+        ids.extend(
+            self.model.encode(
+                text,
+                allowed_special=allowed_special,
+            )
+        )
+
         if eos:
-            tokens.append(self.special_tokens["<|end_of_text|>"])
-        return tokens
+            ids.append(self.special["<|end_of_text|>"])
+
+        return ids

-    def decode(self, tokens):
-        return self.model.decode(tokens)
+    def decode(self, ids):
+        return self.model.decode(ids)


 class ChatFormat:
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-
-    def encode_header(self, message):
-        tokens = []
-        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
-        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
-        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
-        return tokens
-
-    def encode(self, text, allowed_special=None):
-        message = {
-            "role": "user",
-            "content": text
-        }
-        tokens = self.encode_header(message)
-        tokens.extend(
-            self.tokenizer.encode(
-                message["content"].strip(),
-                bos=False,
-                eos=False,
-                allowed_special=allowed_special
-            )
-        )
-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
-        return tokens
-
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids)
+    def __init__(self, tokenizer: Llama3Tokenizer, *,
+                 default_system="You are a helpful assistant."):
+        self.tok = tokenizer
+        self.default_system = default_system
+
+    def _header(self, role):
+        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
+        return (
+            [self.tok.special["<|start_header_id|>"]]
+            + self.tok.encode(role)
+            + [self.tok.special["<|end_header_id|>"]]
+            + self.tok.encode("\n\n")
+        )
+
+    def encode(self, user_message, system_message=None, allowed_special=None):
+        sys_msg = system_message if system_message is not None else self.default_system
+
+        ids = [self.tok.special["<|begin_of_text|>"]]
+
+        # system
+        ids += self._header("system")
+        ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
+        ids += [self.tok.special["<|eot_id|>"]]
+
+        # user
+        ids += self._header("user")
+        ids += self.tok.encode(user_message)
+        ids += [self.tok.special["<|eot_id|>"]]
+
+        # assistant header (no content yet)
+        ids += self._header("assistant")
+
+        return ids
+
+    def decode(self, ids):
+        return self.tok.decode(ids)


 def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
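For reference, a minimal usage sketch of the rewritten tokenizer classes; the `tokenizer.model` path is a placeholder and must point to Meta's downloaded Llama 3 tokenizer file:

```python
from llms_from_scratch.llama3 import Llama3Tokenizer, ChatFormat

tok = Llama3Tokenizer("tokenizer.model")   # placeholder path
chat = ChatFormat(tok)                     # uses the default system prompt

ids = chat.encode("What do llamas eat?")
# ChatFormat.encode now produces the full chat scaffold:
#   <|begin_of_text|>
#   <|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n
print(tok.decode(ids))
```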
@@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module):
         self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

-        if cfg["orig_context_length"] != cfg["context_length"]:
-            cfg["rope_base"] = rescale_theta(
-                cfg["rope_base"],
-                cfg["orig_context_length"],
-                cfg["context_length"]
-            )
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],


@@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.llama3 import (
     compute_rope_params,
     apply_rope,
-    rescale_theta,
     LLAMA32_CONFIG_1B,
     GroupedQueryAttention,
     GroupedQueryAttentionFast,
@@ -102,23 +101,6 @@ GPT_CONFIG_124M = {
 }

-def test_rescale():
-    new_theta = rescale_theta(
-        theta_old=500_000.,
-        context_length_old=131_072,
-        context_length_new=8192
-    )
-    assert new_theta == 31250.
-
-    old_theta = rescale_theta(
-        theta_old=new_theta,
-        context_length_old=8192,
-        context_length_new=131_072
-    )
-    assert old_theta == 500_000.
-
-
 def test_grouped_query_attention_equivalence():
     torch.manual_seed(42)
     b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
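For the record, the deleted `rescale_theta` helper scaled RoPE's theta base linearly with the context-length ratio, and the arithmetic the removed test asserted is easy to verify by hand:

```python
# What rescale_theta computed (shown only to document the removed behavior):
theta_old = 500_000.0
assert theta_old * (8_192 / 131_072) == 31_250.0    # scale down to the short context
assert 31_250.0 * (131_072 / 8_192) == 500_000.0    # and back
```

Since the model now keeps its native context length, theta stays at 500_000 and the sampled continuations change, which is why the expected token IDs in the next hunk were updated.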
@@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
     )
     print("Encoded output text:", out)
     expect = torch.tensor([
-        [43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
+        [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
     ])
     assert torch.equal(expect, out)