Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00
Reduce Llama 3 RoPE memory requirements (#658)
* Llama3 from scratch improvements
* Fix Llama 3 expensive RoPE memory issue
* updates
* update package
* benchmark
* remove unused rescale_theta
Committed by GitHub
Parent 55e2a0978a
Commit a3c4c33347
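At its core, the memory fix swaps a causal-attention mask that was registered as a buffer for the full configured context length for a small mask built on the fly for the tokens actually being processed, as in the `Llama3Model` changes further down. A minimal sketch of the idea (illustrative only, not the repository's exact code):

```python
import torch

context_length = 131_072  # full Llama 3.2 training context
num_tokens = 8            # tokens in the current batch

# Old approach (sketch): a persistent buffer sized by the configured context length.
# A bool tensor of shape (131_072, 131_072) alone needs ~16 GiB, which is why the
# context length used to be capped at 8192 and the RoPE theta rescaled to match.
# mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()

# New approach (sketch): build the mask per forward pass for the current sequence only.
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
print(mask.shape)  # torch.Size([8, 8]) -- negligible memory
```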
@@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
from llms_from_scratch.appendix_d import find_highest_gradient, train_model
```

### Llama 3 (Bonus material)

```python
from llms_from_scratch.llama3 import (
    Llama3Model,
    Llama3ModelFast,
    Llama3Tokenizer,
    ChatFormat,
    clean_text
)
```

(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
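A minimal usage sketch for the imports listed above (weight loading and the tokenizer file are omitted; it assumes, consistent with the source changes below, that `LLAMA32_CONFIG_1B` is also exported, that the model constructor takes the config dict, and that the randomly initialized ~1B-parameter model fits in memory):

```python
import torch
from llms_from_scratch.llama3 import Llama3Model, LLAMA32_CONFIG_1B

model = Llama3Model(LLAMA32_CONFIG_1B)   # randomly initialized weights

dummy_ids = torch.randint(0, LLAMA32_CONFIG_1B["vocab_size"], (1, 8))  # batch of 1, 8 token IDs
with torch.no_grad():
    logits = model(dummy_ids)

print(logits.shape)  # expected: torch.Size([1, 8, 128256]), one logit vector per token
```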
@@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe

LLAMA32_CONFIG_1B = {
    "vocab_size": 128_256,           # Vocabulary size
    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
    "orig_context_length": 131_072,  # Context length that was used to train the model
    "context_length": 131_072,       # Context length that was used to train the model
    "emb_dim": 2048,                 # Embedding dimension
    "n_heads": 32,                   # Number of attention heads
    "n_layers": 16,                  # Number of layers

@@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = {

LLAMA32_CONFIG_3B = {
    "vocab_size": 128_256,           # Vocabulary size
    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
    "orig_context_length": 131_072,  # Context length that was used to train the model
    "context_length": 131_072,       # Context length that was used to train the model
    "emb_dim": 3072,                 # Embedding dimension
    "n_heads": 24,                   # Number of attention heads
    "n_layers": 28,                  # Number of layers
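Rough numbers behind this config change (back-of-the-envelope, not from the commit itself): while the mask lived in a persistent buffer, the configured context length carried a quadratic memory cost, which is why it had previously been capped at 8192 with a rescaled RoPE theta. Once the mask is built per forward pass, the full 131_072-token context can stay in the config:

```python
# Size of a boolean causal mask of shape (context_length, context_length):
for ctx in (8_192, 131_072):
    n_bytes = ctx * ctx  # torch.bool uses one byte per element
    print(f"{ctx:>7} tokens -> {n_bytes / 2**30:.1f} GiB mask")

#    8192 tokens -> 0.1 GiB mask   (the old workaround)
#  131072 tokens -> 16.0 GiB mask  (why the full context was previously impractical)
```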
@@ -67,17 +65,6 @@ class Llama3Model(nn.Module):
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusuable utilities
        self.register_buffer(
            "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
            persistent=False
        )

        if cfg["orig_context_length"] != cfg["context_length"]:
            cfg["rope_base"] = rescale_theta(
                cfg["rope_base"],
                cfg["orig_context_length"],
                cfg["context_length"]
            )
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
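For reference, a simplified sketch of the kind of cos/sin tables a RoPE precomputation like `compute_rope_params` produces (standard RoPE formulation; the repository's version additionally takes the Llama 3.2 frequency-scaling settings into account):

```python
import torch

def rope_params_sketch(head_dim, theta_base=500_000.0, context_length=131_072):
    # One inverse frequency per dimension pair: theta_base ** (-2i / head_dim)
    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    positions = torch.arange(context_length, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)       # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

cos, sin = rope_params_sketch(head_dim=2048 // 32, context_length=131_072)
print(cos.shape)  # torch.Size([131072, 64]); cos and sin together are ~64 MB in float32
```

Unlike the quadratic attention mask, these tables grow only linearly with context length, so keeping them as buffers for the full 131_072-token window is cheap.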
@@ -92,8 +79,11 @@ class Llama3Model(nn.Module):
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, self.mask, self.cos, self.sin)
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
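For completeness, a sketch of how such a boolean upper-triangular mask is typically consumed inside the attention blocks; the repository's `GroupedQueryAttention` differs in detail (grouped key/value heads, RoPE), but the masking step is the same idea:

```python
import torch

def causal_attention_weights_sketch(queries, keys, mask):
    # queries, keys: (batch, num_heads, num_tokens, head_dim); mask: (num_tokens, num_tokens) bool
    scores = queries @ keys.transpose(-2, -1) / keys.shape[-1] ** 0.5
    scores = scores.masked_fill(mask, float("-inf"))  # True marks future positions to block
    return torch.softmax(scores, dim=-1)

q = k = torch.randn(1, 4, 8, 16)
mask = torch.triu(torch.ones(8, 8, dtype=torch.bool), diagonal=1)
weights = causal_attention_weights_sketch(q, k, mask)
print(weights[0, 0, 0])  # the first token attends only to itself
```

Because the mask now depends only on the number of tokens in the current batch, nothing quadratic in the 131_072-token context length is ever materialized.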
@@ -281,88 +271,104 @@ def apply_rope(x, cos, sin):
    return x_rotated.to(dtype=x.dtype)


def rescale_theta(theta_old, context_length_old, context_length_new):
    scaling_factor = context_length_new / context_length_old
    theta_new = theta_old * scaling_factor
    return theta_new
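The removed helper scaled the RoPE base linearly with the ratio of new to old context length. Plugging in the previous defaults shows the effect (the unit test removed further down checked exactly this round trip):

```python
theta_old = 500_000.0              # Llama 3 rope_base
scaling_factor = 8192 / 131_072    # shortened context / original training context
print(theta_old * scaling_factor)  # 31250.0 -- the theta used with the 8192-token window
```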
##########################################
# Tokenizer
##########################################


class Llama3Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
    def __init__(self, model_path):
        assert os.path.isfile(model_path), f"Model file {model_path} not found"
        mergeable_ranks = load_tiktoken_bpe(model_path)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        self.special_tokens = {
        mergeable = load_tiktoken_bpe(model_path)

        # hard-coded from Meta's tokenizer.json
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        self.special_tokens.update({
            f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()
        })
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
    def encode(self, text, bos=False, eos=False, allowed_special=set()):
        ids: list[int] = []

        if bos:
            tokens = [self.special_tokens["<|begin_of_text|>"]]
        else:
            tokens = []

        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
            ids.append(self.special_tokens["<|begin_of_text|>"])

        # delegate to underlying tiktoken.Encoding.encode
        ids.extend(
            self.model.encode(
                text,
                allowed_special=allowed_special,
            )
        )

        if eos:
            tokens.append(self.special_tokens["<|end_of_text|>"])
        return tokens
            ids.append(self.special_tokens["<|end_of_text|>"])

    def decode(self, tokens):
        return self.model.decode(tokens)
        return ids

    def decode(self, ids):
        return self.model.decode(ids)
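An illustrative usage sketch for the rewritten tokenizer (the path to Meta's `tokenizer.model` file is a placeholder; obtaining it is covered in the bonus material linked in the README above):

```python
from llms_from_scratch.llama3 import Llama3Tokenizer

tokenizer = Llama3Tokenizer("Llama-3.2-1B/original/tokenizer.model")  # placeholder path

ids = tokenizer.encode("Hello, world!")  # plain BPE ids, no special tokens
print(tokenizer.decode(ids))             # "Hello, world!"

# bos=True / eos=True wrap the ids in <|begin_of_text|> (128000) and <|end_of_text|> (128001)
```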
class ChatFormat:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message):
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens
    def __init__(self, tokenizer: Llama3Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def encode(self, text, allowed_special=None):
        message = {
            "role": "user",
            "content": text
        }

        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(
                message["content"].strip(),
                bos=False,
                eos=False,
                allowed_special=allowed_special
            )
    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids)
    def encode(self, user_message, system_message=None, allowed_special=None):
        sys_msg = system_message if system_message is not None else self.default_system

        ids = [self.tok.special["<|begin_of_text|>"]]

        # system
        ids += self._header("system")
        ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
        ids += [self.tok.special["<|eot_id|>"]]

        # user
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # assistant header (no content yet)
        ids += self._header("assistant")

        return ids

    def decode(self, ids):
        return self.tok.decode(ids)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
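The new `ChatFormat.encode` assembles the standard Llama 3 instruct prompt, ending with an open assistant header so the model's reply can be generated next. An illustrative sketch of the resulting layout (tokenizer path as above; `allowed_special=set()` is passed explicitly since the wrapper forwards it to tiktoken unchanged):

```python
from llms_from_scratch.llama3 import ChatFormat, Llama3Tokenizer

tokenizer = Llama3Tokenizer("Llama-3.2-1B/original/tokenizer.model")  # placeholder path
chat = ChatFormat(tokenizer)

ids = chat.encode("What do llamas eat?", allowed_special=set())

# Decoded, the prompt has roughly this shape:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# What do llamas eat?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# clean_text() strips everything up to and including that assistant header from the
# generated output, leaving only the model's reply.
```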
@@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module):
        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        if cfg["orig_context_length"] != cfg["context_length"]:
            cfg["rope_base"] = rescale_theta(
                cfg["rope_base"],
                cfg["orig_context_length"],
                cfg["context_length"]
            )
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
@@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
    compute_rope_params,
    apply_rope,
    rescale_theta,
    LLAMA32_CONFIG_1B,
    GroupedQueryAttention,
    GroupedQueryAttentionFast,
@@ -102,23 +101,6 @@ GPT_CONFIG_124M = {
}


def test_rescale():

    new_theta = rescale_theta(
        theta_old=500_000.,
        context_length_old=131_072,
        context_length_new=8192
    )
    assert new_theta == 31250.

    old_theta = rescale_theta(
        theta_old=new_theta,
        context_length_old=8192,
        context_length_new=131_072
    )
    assert old_theta == 500_000.


def test_grouped_query_attention_equivalence():
    torch.manual_seed(42)
    b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
@@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
    )
    print("Encoded output text:", out)
    expect = torch.tensor([
        [43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
        [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
    ])
    assert torch.equal(expect, out)