BPE whitespace fixes (#975)

This commit is contained in:
Sebastian Raschka
2026-03-07 14:56:25 -05:00
committed by GitHub
parent 3a7b98a36a
commit 052c2dea4f
2 changed files with 82 additions and 68 deletions

View File

@@ -416,23 +416,14 @@
" allowed_special (set): A set of special tokens to include.\n", " allowed_special (set): A set of special tokens to include.\n",
" \"\"\"\n", " \"\"\"\n",
"\n", "\n",
" # Preprocess: Replace spaces with \"Ġ\"\n", " # Pre-tokenize training text using the same boundary rules as encode()\n",
" # Note that Ġ is a particularity of the GPT-2 BPE implementation\n", " tokens = self.pretokenize_text(text)\n",
" # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
" # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
" processed_text = []\n",
" for i, char in enumerate(text):\n",
" if char == \" \" and i != 0:\n",
" processed_text.append(\"Ġ\")\n",
" if char != \" \":\n",
" processed_text.append(char)\n",
" processed_text = \"\".join(processed_text)\n",
"\n", "\n",
" # Initialize vocab with unique characters, including \"Ġ\" if present\n", " # Initialize vocab with unique characters, including \"Ġ\" if present\n",
" # Start with the first 256 ASCII characters\n", " # Start with the first 256 ASCII characters\n",
" unique_chars = [chr(i) for i in range(256)]\n", " unique_chars = [chr(i) for i in range(256)]\n",
" unique_chars.extend(\n", " unique_chars.extend(\n",
" char for char in sorted(set(processed_text))\n", " char for char in sorted({char for token in tokens for char in token})\n",
" if char not in unique_chars\n", " if char not in unique_chars\n",
" )\n", " )\n",
" if \"Ġ\" not in unique_chars:\n", " if \"Ġ\" not in unique_chars:\n",
@@ -449,15 +440,18 @@
" self.vocab[new_id] = token\n", " self.vocab[new_id] = token\n",
" self.inverse_vocab[token] = new_id\n", " self.inverse_vocab[token] = new_id\n",
"\n", "\n",
" # Tokenize the processed_text into token IDs\n", " # Tokenize each pre-token into character IDs\n",
" token_ids = [self.inverse_vocab[char] for char in processed_text]\n", " token_id_sequences = [\n",
" [self.inverse_vocab[char] for char in token]\n",
" for token in tokens\n",
" ]\n",
"\n", "\n",
" # BPE steps 1-3: Repeatedly find and replace frequent pairs\n", " # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
" for new_id in range(len(self.vocab), vocab_size):\n", " for new_id in range(len(self.vocab), vocab_size):\n",
" pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n", " pair_id = self.find_freq_pair(token_id_sequences, mode=\"most\")\n",
" if pair_id is None:\n", " if pair_id is None:\n",
" break\n", " break\n",
" token_ids = self.replace_pair(token_ids, pair_id, new_id)\n", " token_id_sequences = self.replace_pair(token_id_sequences, pair_id, new_id)\n",
" self.bpe_merges[pair_id] = new_id\n", " self.bpe_merges[pair_id] = new_id\n",
"\n", "\n",
" # Build the vocabulary with merged tokens\n", " # Build the vocabulary with merged tokens\n",
@@ -581,43 +575,7 @@
"\n", "\n",
" \n", " \n",
" # ---- Newline and carriage return handling ----\n", " # ---- Newline and carriage return handling ----\n",
" tokens = []\n", " tokens = self.pretokenize_text(text)\n",
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
" for part in parts:\n",
" if part == \"\":\n",
" continue\n",
" if part == \"\\r\\n\":\n",
" tokens.append(\"\\r\")\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" if part == \"\\r\":\n",
" tokens.append(\"\\r\")\n",
" continue\n",
" if part == \"\\n\":\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" \n",
" # Normal chunk without line breaks:\n",
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
" # add standalone 'Ġ' for additional spaces\n",
" # - If spaces trail the chunk (e.g., before a newline) add\n",
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
" pending_spaces = 0\n",
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
" if m.group(1) is not None:\n",
" pending_spaces += len(m.group(1))\n",
" else:\n",
" word = m.group(2)\n",
" if pending_spaces > 0:\n",
" for _ in range(pending_spaces - 1):\n",
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
" tokens.append(\"Ġ\" + word) # one leading space\n",
" pending_spaces = 0\n",
" else:\n",
" tokens.append(word)\n",
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
" for _ in range(pending_spaces):\n",
" tokens.append(\"Ġ\")\n",
" # ---------------------------------------------------------------\n", " # ---------------------------------------------------------------\n",
" \n", " \n",
" # Map tokens -> ids (BPE if needed)\n", " # Map tokens -> ids (BPE if needed)\n",
@@ -786,8 +744,53 @@
" return self.inverse_vocab.get(token, None)\n", " return self.inverse_vocab.get(token, None)\n",
"\n", "\n",
" @staticmethod\n", " @staticmethod\n",
" def find_freq_pair(token_ids, mode=\"most\"):\n", " def pretokenize_text(text):\n",
" pairs = Counter(zip(token_ids, token_ids[1:]))\n", " tokens = []\n",
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
" for part in parts:\n",
" if part == \"\":\n",
" continue\n",
" if part == \"\\r\\n\":\n",
" tokens.append(\"\\r\")\n",
" tokens.append(\"\\n\")\n",
" continue\n",
" if part == \"\\r\":\n",
" tokens.append(\"\\r\")\n",
" continue\n",
" if part == \"\\n\":\n",
" tokens.append(\"\\n\")\n",
" continue\n",
"\n",
" # Normal chunk without line breaks:\n",
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
" # add standalone 'Ġ' for additional spaces\n",
" # - If spaces trail the chunk (e.g., before a newline) add\n",
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
" pending_spaces = 0\n",
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
" if m.group(1) is not None:\n",
" pending_spaces += len(m.group(1))\n",
" else:\n",
" word = m.group(2)\n",
" if pending_spaces > 0:\n",
" for _ in range(pending_spaces - 1):\n",
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
" tokens.append(\"Ġ\" + word) # one leading space\n",
" pending_spaces = 0\n",
" else:\n",
" tokens.append(word)\n",
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
" for _ in range(pending_spaces):\n",
" tokens.append(\"Ġ\")\n",
" return tokens\n",
"\n",
" @staticmethod\n",
" def find_freq_pair(token_id_sequences, mode=\"most\"):\n",
" pairs = Counter(\n",
" pair\n",
" for token_ids in token_id_sequences\n",
" for pair in zip(token_ids, token_ids[1:])\n",
" )\n",
"\n", "\n",
" if not pairs:\n", " if not pairs:\n",
" return None\n", " return None\n",
@@ -800,20 +803,25 @@
" raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n", " raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n",
"\n", "\n",
" @staticmethod\n", " @staticmethod\n",
" def replace_pair(token_ids, pair_id, new_id):\n", " def replace_pair(token_id_sequences, pair_id, new_id):\n",
" dq = deque(token_ids)\n", " replaced_sequences = []\n",
" replaced = []\n",
"\n", "\n",
" while dq:\n", " for token_ids in token_id_sequences:\n",
" current = dq.popleft()\n", " dq = deque(token_ids)\n",
" if dq and (current, dq[0]) == pair_id:\n", " replaced = []\n",
" replaced.append(new_id)\n",
" # Remove the 2nd token of the pair, 1st was already removed\n",
" dq.popleft()\n",
" else:\n",
" replaced.append(current)\n",
"\n", "\n",
" return replaced" " while dq:\n",
" current = dq.popleft()\n",
" if dq and (current, dq[0]) == pair_id:\n",
" replaced.append(new_id)\n",
" # Remove the 2nd token of the pair, 1st was already removed\n",
" dq.popleft()\n",
" else:\n",
" replaced.append(current)\n",
"\n",
" replaced_sequences.append(replaced)\n",
"\n",
" return replaced_sequences"
] ]
}, },
{ {

View File

@@ -88,8 +88,14 @@ def test_tokenizer_training(imported_module, verdict_file):
assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch." assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."
input_text = "Jack embraced beauty through art and life." input_text = "Jack embraced beauty through art and life."
invalid_whitespace_tokens = [
tok for tok in tokenizer.vocab.values()
if "Ġ" in tok and tok != "Ġ" and not tok.startswith("Ġ")
]
assert not invalid_whitespace_tokens, "Training should not learn tokens with non-leading Ġ markers."
token_ids = tokenizer.encode(input_text) token_ids = tokenizer.encode(input_text)
assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output." assert token_ids == [74, 361, 310, 109, 98, 420, 397, 100, 300, 428, 116, 121, 519, 699, 299, 808, 534], "Token IDs do not match expected output."
assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input." assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."