mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 04:23:41 +00:00
Bpe whitespace fixes (#975)
This commit is contained in:
committed by
GitHub
parent
3a7b98a36a
commit
052c2dea4f
@@ -416,23 +416,14 @@
|
||||
" allowed_special (set): A set of special tokens to include.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Preprocess: Replace spaces with \"Ġ\"\n",
|
||||
" # Note that Ġ is a particularity of the GPT-2 BPE implementation\n",
|
||||
" # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
|
||||
" # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
|
||||
" processed_text = []\n",
|
||||
" for i, char in enumerate(text):\n",
|
||||
" if char == \" \" and i != 0:\n",
|
||||
" processed_text.append(\"Ġ\")\n",
|
||||
" if char != \" \":\n",
|
||||
" processed_text.append(char)\n",
|
||||
" processed_text = \"\".join(processed_text)\n",
|
||||
" # Pre-tokenize training text using the same boundary rules as encode()\n",
|
||||
" tokens = self.pretokenize_text(text)\n",
|
||||
"\n",
|
||||
" # Initialize vocab with unique characters, including \"Ġ\" if present\n",
|
||||
" # Start with the first 256 ASCII characters\n",
|
||||
" unique_chars = [chr(i) for i in range(256)]\n",
|
||||
" unique_chars.extend(\n",
|
||||
" char for char in sorted(set(processed_text))\n",
|
||||
" char for char in sorted({char for token in tokens for char in token})\n",
|
||||
" if char not in unique_chars\n",
|
||||
" )\n",
|
||||
" if \"Ġ\" not in unique_chars:\n",
|
||||
@@ -449,15 +440,18 @@
|
||||
" self.vocab[new_id] = token\n",
|
||||
" self.inverse_vocab[token] = new_id\n",
|
||||
"\n",
|
||||
" # Tokenize the processed_text into token IDs\n",
|
||||
" token_ids = [self.inverse_vocab[char] for char in processed_text]\n",
|
||||
" # Tokenize each pre-token into character IDs\n",
|
||||
" token_id_sequences = [\n",
|
||||
" [self.inverse_vocab[char] for char in token]\n",
|
||||
" for token in tokens\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
|
||||
" for new_id in range(len(self.vocab), vocab_size):\n",
|
||||
" pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n",
|
||||
" pair_id = self.find_freq_pair(token_id_sequences, mode=\"most\")\n",
|
||||
" if pair_id is None:\n",
|
||||
" break\n",
|
||||
" token_ids = self.replace_pair(token_ids, pair_id, new_id)\n",
|
||||
" token_id_sequences = self.replace_pair(token_id_sequences, pair_id, new_id)\n",
|
||||
" self.bpe_merges[pair_id] = new_id\n",
|
||||
"\n",
|
||||
" # Build the vocabulary with merged tokens\n",
|
||||
@@ -581,43 +575,7 @@
|
||||
"\n",
|
||||
" \n",
|
||||
" # ---- Newline and carriage return handling ----\n",
|
||||
" tokens = []\n",
|
||||
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
|
||||
" for part in parts:\n",
|
||||
" if part == \"\":\n",
|
||||
" continue\n",
|
||||
" if part == \"\\r\\n\":\n",
|
||||
" tokens.append(\"\\r\")\n",
|
||||
" tokens.append(\"\\n\")\n",
|
||||
" continue\n",
|
||||
" if part == \"\\r\":\n",
|
||||
" tokens.append(\"\\r\")\n",
|
||||
" continue\n",
|
||||
" if part == \"\\n\":\n",
|
||||
" tokens.append(\"\\n\")\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" # Normal chunk without line breaks:\n",
|
||||
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
|
||||
" # add standalone 'Ġ' for additional spaces\n",
|
||||
" # - If spaces trail the chunk (e.g., before a newline) add\n",
|
||||
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
|
||||
" pending_spaces = 0\n",
|
||||
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
|
||||
" if m.group(1) is not None:\n",
|
||||
" pending_spaces += len(m.group(1))\n",
|
||||
" else:\n",
|
||||
" word = m.group(2)\n",
|
||||
" if pending_spaces > 0:\n",
|
||||
" for _ in range(pending_spaces - 1):\n",
|
||||
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
|
||||
" tokens.append(\"Ġ\" + word) # one leading space\n",
|
||||
" pending_spaces = 0\n",
|
||||
" else:\n",
|
||||
" tokens.append(word)\n",
|
||||
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
|
||||
" for _ in range(pending_spaces):\n",
|
||||
" tokens.append(\"Ġ\")\n",
|
||||
" tokens = self.pretokenize_text(text)\n",
|
||||
" # ---------------------------------------------------------------\n",
|
||||
" \n",
|
||||
" # Map tokens -> ids (BPE if needed)\n",
|
||||
@@ -786,8 +744,53 @@
|
||||
" return self.inverse_vocab.get(token, None)\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def find_freq_pair(token_ids, mode=\"most\"):\n",
|
||||
" pairs = Counter(zip(token_ids, token_ids[1:]))\n",
|
||||
" def pretokenize_text(text):\n",
|
||||
" tokens = []\n",
|
||||
" parts = re.split(r'(\\r\\n|\\r|\\n)', text)\n",
|
||||
" for part in parts:\n",
|
||||
" if part == \"\":\n",
|
||||
" continue\n",
|
||||
" if part == \"\\r\\n\":\n",
|
||||
" tokens.append(\"\\r\")\n",
|
||||
" tokens.append(\"\\n\")\n",
|
||||
" continue\n",
|
||||
" if part == \"\\r\":\n",
|
||||
" tokens.append(\"\\r\")\n",
|
||||
" continue\n",
|
||||
" if part == \"\\n\":\n",
|
||||
" tokens.append(\"\\n\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" # Normal chunk without line breaks:\n",
|
||||
" # - If spaces precede a word, prefix the first word with 'Ġ' and\n",
|
||||
" # add standalone 'Ġ' for additional spaces\n",
|
||||
" # - If spaces trail the chunk (e.g., before a newline) add\n",
|
||||
" # standalone 'Ġ' tokens (tiktoken produces id 220 for 'Ġ')\n",
|
||||
" pending_spaces = 0\n",
|
||||
" for m in re.finditer(r'( +)|(\\S+)', part):\n",
|
||||
" if m.group(1) is not None:\n",
|
||||
" pending_spaces += len(m.group(1))\n",
|
||||
" else:\n",
|
||||
" word = m.group(2)\n",
|
||||
" if pending_spaces > 0:\n",
|
||||
" for _ in range(pending_spaces - 1):\n",
|
||||
" tokens.append(\"Ġ\") # remaining spaces as standalone\n",
|
||||
" tokens.append(\"Ġ\" + word) # one leading space\n",
|
||||
" pending_spaces = 0\n",
|
||||
" else:\n",
|
||||
" tokens.append(word)\n",
|
||||
" # Trailing spaces (no following word): add standalone 'Ġ' tokens\n",
|
||||
" for _ in range(pending_spaces):\n",
|
||||
" tokens.append(\"Ġ\")\n",
|
||||
" return tokens\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def find_freq_pair(token_id_sequences, mode=\"most\"):\n",
|
||||
" pairs = Counter(\n",
|
||||
" pair\n",
|
||||
" for token_ids in token_id_sequences\n",
|
||||
" for pair in zip(token_ids, token_ids[1:])\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if not pairs:\n",
|
||||
" return None\n",
|
||||
@@ -800,20 +803,25 @@
|
||||
" raise ValueError(\"Invalid mode. Choose 'most' or 'least'.\")\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def replace_pair(token_ids, pair_id, new_id):\n",
|
||||
" dq = deque(token_ids)\n",
|
||||
" replaced = []\n",
|
||||
" def replace_pair(token_id_sequences, pair_id, new_id):\n",
|
||||
" replaced_sequences = []\n",
|
||||
"\n",
|
||||
" while dq:\n",
|
||||
" current = dq.popleft()\n",
|
||||
" if dq and (current, dq[0]) == pair_id:\n",
|
||||
" replaced.append(new_id)\n",
|
||||
" # Remove the 2nd token of the pair, 1st was already removed\n",
|
||||
" dq.popleft()\n",
|
||||
" else:\n",
|
||||
" replaced.append(current)\n",
|
||||
" for token_ids in token_id_sequences:\n",
|
||||
" dq = deque(token_ids)\n",
|
||||
" replaced = []\n",
|
||||
"\n",
|
||||
" return replaced"
|
||||
" while dq:\n",
|
||||
" current = dq.popleft()\n",
|
||||
" if dq and (current, dq[0]) == pair_id:\n",
|
||||
" replaced.append(new_id)\n",
|
||||
" # Remove the 2nd token of the pair, 1st was already removed\n",
|
||||
" dq.popleft()\n",
|
||||
" else:\n",
|
||||
" replaced.append(current)\n",
|
||||
"\n",
|
||||
" replaced_sequences.append(replaced)\n",
|
||||
"\n",
|
||||
" return replaced_sequences"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -88,8 +88,14 @@ def test_tokenizer_training(imported_module, verdict_file):
|
||||
assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch."
|
||||
|
||||
input_text = "Jack embraced beauty through art and life."
|
||||
invalid_whitespace_tokens = [
|
||||
tok for tok in tokenizer.vocab.values()
|
||||
if "Ġ" in tok and tok != "Ġ" and not tok.startswith("Ġ")
|
||||
]
|
||||
assert not invalid_whitespace_tokens, "Training should not learn tokens with non-leading Ġ markers."
|
||||
|
||||
token_ids = tokenizer.encode(input_text)
|
||||
assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output."
|
||||
assert token_ids == [74, 361, 310, 109, 98, 420, 397, 100, 300, 428, 116, 121, 519, 699, 299, 808, 534], "Token IDs do not match expected output."
|
||||
|
||||
assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input."
|
||||
|
||||
|
||||
Reference in New Issue
Block a user