mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Bpe whitespace fixes (#975)
This commit is contained in:
committed by
GitHub
parent
3a7b98a36a
commit
052c2dea4f
@@ -416,23 +416,14 @@
|
|||||||
" allowed_special (set): A set of special tokens to include.\n",
|
" allowed_special (set): A set of special tokens to include.\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Preprocess: Replace spaces with \"Ġ\"\n",
|
" # Pre-tokenize training text using the same boundary rules as encode()\n",
|
||||||
" # Note that Ġ is a particularity of the GPT-2 BPE implementation\n",
|
" tokens = self.pretokenize_text(text)\n",
|
||||||
" # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
|
|
||||||
" # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
|
|
||||||
" processed_text = []\n",
|
|
||||||
" for i, char in enumerate(text):\n",
|
|
||||||
" if char == \" \" and i != 0:\n",
|
|
||||||
" processed_text.append(\"Ġ\")\n",
|
|
||||||
" if char != \" \":\n",
|
|
||||||
" processed_text.append(char)\n",
|
|
||||||
" processed_text = \"\".join(processed_text)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" # Initialize vocab with unique characters, including \"Ġ\" if present\n",
|
" # Initialize vocab with unique characters, including \"Ġ\" if present\n",
|
||||||
" # Start with the first 256 ASCII characters\n",
|
" # Start with the first 256 ASCII characters\n",
|
||||||
" unique_chars = [chr(i) for i in range(256)]\n",
|
" unique_chars = [chr(i) for i in range(256)]\n",
|
||||||
" unique_chars.extend(\n",
|
" unique_chars.extend(\n",
|
||||||
" char for char in sorted(set(processed_text))\n",
|
" char for char in sorted({char for token in tokens for char in token})\n",
|
||||||
" if char not in unique_chars\n",
|
" if char not in unique_chars\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" if \"Ġ\" not in unique_chars:\n",
|
" if \"Ġ\" not in unique_chars:\n",
|
||||||
@@ -449,15 +440,18 @@
|
|||||||
" self.vocab[new_id] = token\n",
|
" self.vocab[new_id] = token\n",
|
||||||
" self.inverse_vocab[token] = new_id\n",
|
" self.inverse_vocab[token] = new_id\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Tokenize the processed_text into token IDs\n",
|
" # Tokenize each pre-token into character IDs\n",
|
||||||
" token_ids = [self.inverse_vocab[char] for char in processed_text]\n",
|
" token_id_sequences = [\n",
|
||||||
|
" [self.inverse_vocab[char] for char in token]\n",
|
||||||
|
" for token in tokens\n",
|
||||||
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
|
" # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
|
||||||
" for new_id in range(len(self.vocab), vocab_size):\n",
|
" for new_id in range(len(self.vocab), vocab_size):\n",
|
||||||
" pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n",
|
" pair_id = self.find_freq_pair(token_id_sequences, mode=\"most\")\n",
|
||||||
" if pair_id is None:\n",
|
" if pair_id is None:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
" token_ids = self.replace_pair(token_ids, pair_id, new_id)\n",
|
" token_id_sequences = self.replace_pair(token_id_sequences, pair_id, new_id)\n",
|
||||||
" self.bpe_merges[pair_id] = new_id\n",
|
" self.bpe_merges[pair_id] = new_id\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Build the vocabulary with merged tokens\n",
|
" # Build the vocabulary with merged tokens\n",
|
||||||
@@ -581,43 +575,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # ---- Newline and carriage return handling ----\n",
|
" # ---- Newline and carriage return handling ----\n",
|
||||||
" tokens = []\n",
|
" tokens = self.pretokenize_text(text)\n",
|
||||||
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
|
|
||||||
" for part in parts:\n",
|
|
||||||
" if part == \"\":\n",
|
|
||||||
" continue\n",
|
|
||||||
" if part == \"\\r\\n\":\n",
|
|
||||||
" tokens.append(\"\\r\")\n",
|
|
||||||
" tokens.append(\"\\n\")\n",
|
|
||||||
" continue\n",
|
|
||||||
" if part == \"\\r\":\n",
|
|
||||||
" tokens.append(\"\\r\")\n",
|
|
||||||
" continue\n",
|
|
||||||
" if part == \"\\n\":\n",
|
|
||||||
" tokens.append(\"\\n\")\n",
|
|
||||||
" continue\n",
|
|
||||||
" \n",
|
|
||||||
" # Normal chunk without line breaks:\n",
|
|
||||||
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
|
|
||||||
" # add standalone 'Ġ' for additional spaces\n",
|
|
||||||
" # - If spaces trail the chunk (e.g., before a newline) add\n",
|
|
||||||
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
|
|
||||||
" pending_spaces = 0\n",
|
|
||||||
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
|
|
||||||
" if m.group(1) is not None:\n",
|
|
||||||
" pending_spaces += len(m.group(1))\n",
|
|
||||||
" else:\n",
|
|
||||||
" word = m.group(2)\n",
|
|
||||||
" if pending_spaces > 0:\n",
|
|
||||||
" for _ in range(pending_spaces - 1):\n",
|
|
||||||
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
|
|
||||||
" tokens.append(\"Ġ\" + word) # one leading space\n",
|
|
||||||
" pending_spaces = 0\n",
|
|
||||||
" else:\n",
|
|
||||||
" tokens.append(word)\n",
|
|
||||||
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
|
|
||||||
" for _ in range(pending_spaces):\n",
|
|
||||||
" tokens.append(\"Ġ\")\n",
|
|
||||||
" # ---------------------------------------------------------------\n",
|
" # ---------------------------------------------------------------\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Map tokens -> ids (BPE if needed)\n",
|
" # Map tokens -> ids (BPE if needed)\n",
|
||||||
@@ -786,8 +744,53 @@
|
|||||||
" return self.inverse_vocab.get(token, None)\n",
|
" return self.inverse_vocab.get(token, None)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @staticmethod\n",
|
" @staticmethod\n",
|
||||||
" def find_freq_pair(token_ids, mode=\"most\"):\n",
|
" def pretokenize_text(text):\n",
|
||||||
" pairs = Counter(zip(token_ids, token_ids[1:]))\n",
|
" tokens = []\n",
|
||||||
|
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
|
||||||
|
" for part in parts:\n",
|
||||||
|
" if part == \"\":\n",
|
||||||
|
" continue\n",
|
||||||
|
" if part == \"\\r\\n\":\n",
|
||||||
|
" tokens.append(\"\\r\")\n",
|
||||||
|
" tokens.append(\"\\n\")\n",
|
||||||
|
" continue\n",
|
||||||
|
" if part == \"\\r\":\n",
|
||||||
|
" tokens.append(\"\\r\")\n",
|
||||||
|
" continue\n",
|
||||||
|
" if part == \"\\n\":\n",
|
||||||
|
" tokens.append(\"\\n\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" # Normal chunk without line breaks:\n",
|
||||||
|
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
|
||||||
|
" # add standalone 'Ġ' for additional spaces\n",
|
||||||
|
" # - If spaces trail the chunk (e.g., before a newline) add\n",
|
||||||
|
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
|
||||||
|
" pending_spaces = 0\n",
|
||||||
|
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
|
||||||
|
" if m.group(1) is not None:\n",
|
||||||
|
" pending_spaces += len(m.group(1))\n",
|
||||||
|
" else:\n",
|
||||||
|
" word = m.group(2)\n",
|
||||||
|
" if pending_spaces > 0:\n",
|
||||||
|
" for _ in range(pending_spaces - 1):\n",
|
||||||
|
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
|
||||||
|
" tokens.append(\"Ġ\" + word) # one leading space\n",
|
||||||
|
" pending_spaces = 0\n",
|
||||||
|
" else:\n",
|
||||||
|
" tokens.append(word)\n",
|
||||||
|
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
|
||||||
|
" for _ in range(pending_spaces):\n",
|
||||||
|
" tokens.append(\"Ġ\")\n",
|
||||||
|
" return tokens\n",
|
||||||
|
"\n",
|
||||||
|
" @staticmethod\n",
|
||||||
|
" def find_freq_pair(token_id_sequences, mode=\"most\"):\n",
|
||||||
|
" pairs = Counter(\n",
|
||||||
|
" pair\n",
|
||||||
|
" for token_ids in token_id_sequences\n",
|
||||||
|
" for pair in zip(token_ids, token_ids[1:])\n",
|
||||||
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if not pairs:\n",
|
" if not pairs:\n",
|
||||||
" return None\n",
|
" return None\n",
|
||||||
@@ -800,20 +803,25 @@
|
|||||||
" raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n",
|
" raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @staticmethod\n",
|
" @staticmethod\n",
|
||||||
" def replace_pair(token_ids, pair_id, new_id):\n",
|
" def replace_pair(token_id_sequences, pair_id, new_id):\n",
|
||||||
" dq = deque(token_ids)\n",
|
" replaced_sequences = []\n",
|
||||||
" replaced = []\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" while dq:\n",
|
" for token_ids in token_id_sequences:\n",
|
||||||
" current = dq.popleft()\n",
|
" dq = deque(token_ids)\n",
|
||||||
" if dq and (current, dq[0]) == pair_id:\n",
|
" replaced = []\n",
|
||||||
" replaced.append(new_id)\n",
|
|
||||||
" # Remove the 2nd token of the pair, 1st was already removed\n",
|
|
||||||
" dq.popleft()\n",
|
|
||||||
" else:\n",
|
|
||||||
" replaced.append(current)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" return replaced"
|
" while dq:\n",
|
||||||
|
" current = dq.popleft()\n",
|
||||||
|
" if dq and (current, dq[0]) == pair_id:\n",
|
||||||
|
" replaced.append(new_id)\n",
|
||||||
|
" # Remove the 2nd token of the pair, 1st was already removed\n",
|
||||||
|
" dq.popleft()\n",
|
||||||
|
" else:\n",
|
||||||
|
" replaced.append(current)\n",
|
||||||
|
"\n",
|
||||||
|
" replaced_sequences.append(replaced)\n",
|
||||||
|
"\n",
|
||||||
|
" return replaced_sequences"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -88,8 +88,14 @@ def test_tokenizer_training(imported_module, verdict_file):
|
|||||||
assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."
|
assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."
|
||||||
|
|
||||||
input_text = "Jack embraced beauty through art and life."
|
input_text = "Jack embraced beauty through art and life."
|
||||||
|
invalid_whitespace_tokens = [
|
||||||
|
tok for tok in tokenizer.vocab.values()
|
||||||
|
if "Ġ" in tok and tok != "Ġ" and not tok.startswith("Ġ")
|
||||||
|
]
|
||||||
|
assert not invalid_whitespace_tokens, "Training should not learn tokens with non-leading Ġ markers."
|
||||||
|
|
||||||
token_ids = tokenizer.encode(input_text)
|
token_ids = tokenizer.encode(input_text)
|
||||||
assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output."
|
assert token_ids == [74, 361, 310, 109, 98, 420, 397, 100, 300, 428, 116, 121, 519, 699, 299, 808, 534], "Token IDs do not match expected output."
|
||||||
|
|
||||||
assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."
|
assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user