From f63f04d8d5e80108038f1198793f76c9fa13b9fd Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sat, 8 Mar 2025 17:21:30 -0600 Subject: [PATCH] Fix BPE bonus materials (#561) * Fix BPE bonus materials * fix bpe implementation * update * Add 'Hello, world. Is this-- a test?' test case * update link to test file * update path handling * update path handling * fix pytest paths --- .github/workflows/basic-tests-linux-uv.yml | 6 + .gitignore | 8 +- .../compare-bpe-tiktoken.ipynb | 36 ++-- .../bpe-from-scratch.ipynb | 197 +++++++++++------- ch02/05_bpe-from-scratch/tests/tests.py | 147 +++++++++++++ 5 files changed, 307 insertions(+), 87 deletions(-) create mode 100644 ch02/05_bpe-from-scratch/tests/tests.py diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml index 4b43190..982e8f9 100644 --- a/.github/workflows/basic-tests-linux-uv.yml +++ b/.github/workflows/basic-tests-linux-uv.yml @@ -60,3 +60,9 @@ jobs: pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb + + - name: Test Selected Bonus Materials + shell: bash + run: | + source .venv/bin/activate + pytest ch02/05_bpe-from-scratch/tests/tests.py diff --git a/.gitignore b/.gitignore index d22fe5d..7f0c181 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ + # Configs and keys ch05/07_gpt_to_llama/config.json ch07/02_dataset-utilities/config.json @@ -63,6 +64,8 @@ ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth ch07/01_main-chapter-code/gpt2/ # Datasets +the-verdict.txt + appendix-E/01_main-chapter-code/sms_spam_collection.zip appendix-E/01_main-chapter-code/sms_spam_collection appendix-E/01_main-chapter-code/train.csv @@ -70,6 +73,7 @@ appendix-E/01_main-chapter-code/test.csv appendix-E/01_main-chapter-code/validation.csv ch02/01_main-chapter-code/number-data.txt +ch02/05_bpe-from-scratch/the-verdict.txt ch05/03_bonus_pretraining_on_gutenberg/gutenberg ch05/03_bonus_pretraining_on_gutenberg/gutenberg_preprocessed @@ -107,7 +111,9 @@ ch02/05_bpe-from-scratch/bpe_merges.txt ch02/05_bpe-from-scratch/encoder.json ch02/05_bpe-from-scratch/vocab.bpe ch02/05_bpe-from-scratch/vocab.json - +encoder.json +vocab.bpe +vocab.json # Other ch0?/0?_user_interface/.chainlit/ diff --git a/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb b/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb index 8e144bf..75f335f 100644 --- a/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb +++ b/ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb @@ -67,7 +67,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tiktoken version: 0.7.0\n" + "tiktoken version: 0.9.0\n" ] } ], @@ -180,8 +180,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Fetching encoder.json: 1.04Mit [00:00, 4.13Mit/s] \n", - "Fetching vocab.bpe: 457kit [00:00, 2.56Mit/s] \n" + "Fetching encoder.json: 1.04Mit [00:00, 3.69Mit/s] \n", + "Fetching vocab.bpe: 457kit [00:00, 2.53Mit/s] \n" ] } ], @@ -256,10 +256,18 @@ "id": "e9077bf4-f91f-42ad-ab76-f3d89128510e", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "data": { "text/plain": [ - "'4.48.0'" + "'4.49.0'" ] }, "execution_count": 12, @@ -423,7 +431,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[1544, 18798, 11, 995, 13, 1148, 256, 5303, 82, 438, 257, 1332, 30]\n" + "[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]\n" ] } ], @@ -451,7 +459,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open('../01_main-chapter-code/the-verdict.txt', 'r', encoding='utf-8') as f:\n", + "with open(\"../01_main-chapter-code/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n", " raw_text = f.read()" ] }, @@ -473,7 +481,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.39 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.84 ms ± 9.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -499,7 +507,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.08 ms ± 5.99 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + "901 μs ± 6.27 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -532,7 +540,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "10.2 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "11 ms ± 94.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -550,7 +558,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "10 ms ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "10.8 ms ± 180 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -575,7 +583,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.79 ms ± 48.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.66 ms ± 3.67 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -593,7 +601,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.83 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.77 ms ± 49.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -619,7 +627,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.59 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + "9.37 ms ± 50.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -644,7 +652,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb index 859cc78..6dc17e0 100644 --- a/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb +++ b/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb @@ -382,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d", "metadata": {}, "outputs": [], @@ -401,6 +401,10 @@ " # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}\n", " self.bpe_merges = {}\n", "\n", + " # For the official OpenAI GPT-2 merges, use a rank dict:\n", + " # of form {(string_A, string_B): rank}, where lower rank = higher priority\n", + " self.bpe_ranks = {}\n", + "\n", " def train(self, text, vocab_size, allowed_special={\"<|endoftext|>\"}):\n", " \"\"\"\n", " Train the BPE tokenizer from scratch.\n", @@ -411,7 +415,7 @@ " allowed_special (set): A set of special tokens to include.\n", " \"\"\"\n", "\n", - " # Preprocess: Replace spaces with 'Ġ'\n", + " # Preprocess: Replace spaces with \"Ġ\"\n", " # Note that Ġ is a particularity of the GPT-2 BPE implementation\n", " # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n", " # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n", @@ -423,18 +427,16 @@ " processed_text.append(char)\n", " processed_text = \"\".join(processed_text)\n", "\n", - " # Initialize vocab with unique characters, including 'Ġ' if present\n", + " # Initialize vocab with unique characters, including \"Ġ\" if present\n", " # Start with the first 256 ASCII characters\n", " unique_chars = [chr(i) for i in range(256)]\n", - "\n", - " # Extend unique_chars with characters from processed_text that are not already included\n", - " unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n", - "\n", - " # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n", + " unique_chars.extend(\n", + " char for char in sorted(set(processed_text))\n", + " if char not in unique_chars\n", + " )\n", " if \"Ġ\" not in unique_chars:\n", " unique_chars.append(\"Ġ\")\n", "\n", - " # Now create the vocab and inverse vocab dictionaries\n", " self.vocab = {i: char for i, char in enumerate(unique_chars)}\n", " self.inverse_vocab = {char: i for i, char in self.vocab.items()}\n", "\n", @@ -452,7 +454,7 @@ " # BPE steps 1-3: Repeatedly find and replace frequent pairs\n", " for new_id in range(len(self.vocab), vocab_size):\n", " pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n", - " if pair_id is None: # No more pairs to merge. Stopping training.\n", + " if pair_id is None:\n", " break\n", " token_ids = self.replace_pair(token_ids, pair_id, new_id)\n", " self.bpe_merges[pair_id] = new_id\n", @@ -492,29 +494,24 @@ " self.inverse_vocab[\"\\n\"] = newline_token_id\n", " self.vocab[newline_token_id] = \"\\n\"\n", "\n", - " # Load BPE merges\n", + " # Load GPT-2 merges and store them with an assigned \"rank\"\n", + " self.bpe_ranks = {} # reset ranks\n", " with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n", " lines = file.readlines()\n", - " # Skip header line if present\n", " if lines and lines[0].startswith(\"#\"):\n", " lines = lines[1:]\n", "\n", + " rank = 0\n", " for line in lines:\n", " pair = tuple(line.strip().split())\n", " if len(pair) == 2:\n", " token1, token2 = pair\n", + " # If token1 or token2 not in vocab, skip\n", " if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n", - " token_id1 = self.inverse_vocab[token1]\n", - " token_id2 = self.inverse_vocab[token2]\n", - " merged_token = token1 + token2\n", - " if merged_token in self.inverse_vocab:\n", - " merged_token_id = self.inverse_vocab[merged_token]\n", - " self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n", - " # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n", - " else:\n", - " print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n", + " self.bpe_ranks[(token1, token2)] = rank\n", + " rank += 1\n", " else:\n", - " print(f\"Skipping pair {pair} as one of the tokens is not in the vocabulary.\")\n", + " print(f\"Skipping pair {pair} as one token is not in the vocabulary.\")\n", "\n", " def encode(self, text):\n", " \"\"\"\n", @@ -540,7 +537,7 @@ " else:\n", " tokens.append(word)\n", " else:\n", - " # Prefix words in the middle of a line with 'Ġ'\n", + " # Prefix words in the middle of a line with \"Ġ\"\n", " tokens.append(\"Ġ\" + word)\n", "\n", " token_ids = []\n", @@ -571,28 +568,74 @@ " missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]\n", " raise ValueError(f\"Characters not found in vocab: {missing_chars}\")\n", "\n", - " can_merge = True\n", - " while can_merge and len(token_ids) > 1:\n", - " can_merge = False\n", - " new_tokens = []\n", - " i = 0\n", - " while i < len(token_ids) - 1:\n", - " pair = (token_ids[i], token_ids[i + 1])\n", - " if pair in self.bpe_merges:\n", - " merged_token_id = self.bpe_merges[pair]\n", - " new_tokens.append(merged_token_id)\n", - " # Uncomment for educational purposes:\n", - " # print(f\"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')\")\n", - " i += 2 # Skip the next token as it's merged\n", - " can_merge = True\n", - " else:\n", + " # If we haven't loaded OpenAI's GPT-2 merges, use my approach\n", + " if not self.bpe_ranks:\n", + " can_merge = True\n", + " while can_merge and len(token_ids) > 1:\n", + " can_merge = False\n", + " new_tokens = []\n", + " i = 0\n", + " while i < len(token_ids) - 1:\n", + " pair = (token_ids[i], token_ids[i + 1])\n", + " if pair in self.bpe_merges:\n", + " merged_token_id = self.bpe_merges[pair]\n", + " new_tokens.append(merged_token_id)\n", + " # Uncomment for educational purposes:\n", + " # print(f\"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')\")\n", + " i += 2 # Skip the next token as it's merged\n", + " can_merge = True\n", + " else:\n", + " new_tokens.append(token_ids[i])\n", + " i += 1\n", + " if i < len(token_ids):\n", " new_tokens.append(token_ids[i])\n", - " i += 1\n", - " if i < len(token_ids):\n", - " new_tokens.append(token_ids[i])\n", - " token_ids = new_tokens\n", + " token_ids = new_tokens\n", + " return token_ids\n", "\n", - " return token_ids\n", + " # Otherwise, do GPT-2-style merging with the ranks:\n", + " # 1) Convert token_ids back to string \"symbols\" for each ID\n", + " symbols = [self.vocab[id_num] for id_num in token_ids]\n", + "\n", + " # Repeatedly merge all occurrences of the lowest-rank pair\n", + " while True:\n", + " # Collect all adjacent pairs\n", + " pairs = set(zip(symbols, symbols[1:]))\n", + " if not pairs:\n", + " break\n", + "\n", + " # Find the pair with the best (lowest) rank\n", + " min_rank = 1_000_000_000\n", + " bigram = None\n", + " for p in pairs:\n", + " r = self.bpe_ranks.get(p, 1_000_000_000)\n", + " if r < min_rank:\n", + " min_rank = r\n", + " bigram = p\n", + "\n", + " # If no valid ranked pair is present, we're done\n", + " if bigram is None or bigram not in self.bpe_ranks:\n", + " break\n", + "\n", + " # Merge all occurrences of that pair\n", + " first, second = bigram\n", + " new_symbols = []\n", + " i = 0\n", + " while i < len(symbols):\n", + " # If we see (first, second) at position i, merge them\n", + " if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:\n", + " new_symbols.append(first + second) # merged symbol\n", + " i += 2\n", + " else:\n", + " new_symbols.append(symbols[i])\n", + " i += 1\n", + " symbols = new_symbols\n", + "\n", + " if len(symbols) == 1:\n", + " break\n", + "\n", + " # Finally, convert merged symbols back to IDs\n", + " merged_ids = [self.inverse_vocab[sym] for sym in symbols]\n", + " return merged_ids\n", "\n", " def decode(self, token_ids):\n", " \"\"\"\n", @@ -738,22 +781,49 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "4d197cad-ed10-4a42-b01c-a763859781fb", + "execution_count": 25, + "id": "51872c08-e01b-40c3-a8a0-e8d6a773e3df", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the-verdict.txt already exists in ./the-verdict.txt\n" + ] + } + ], "source": [ "import os\n", "import urllib.request\n", "\n", - "if not os.path.exists(\"../01_main-chapter-code/the-verdict.txt\"):\n", - " url = (\"https://raw.githubusercontent.com/rasbt/\"\n", - " \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n", - " \"the-verdict.txt\")\n", - " file_path = \"../01_main-chapter-code/the-verdict.txt\"\n", - " urllib.request.urlretrieve(url, file_path)\n", + "def download_file_if_absent(url, filename, search_dirs):\n", + " for directory in search_dirs:\n", + " file_path = os.path.join(directory, filename)\n", + " if os.path.exists(file_path):\n", + " print(f\"{filename} already exists in {file_path}\")\n", + " return file_path\n", "\n", - "with open(\"../01_main-chapter-code/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f: # added ../01_main-chapter-code/\n", + " target_path = os.path.join(search_dirs[0], filename)\n", + " try:\n", + " with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n", + " out_file.write(response.read())\n", + " print(f\"Downloaded {filename} to {target_path}\")\n", + " except Exception as e:\n", + " print(f\"Failed to download {filename}. Error: {e}\")\n", + " return target_path\n", + "\n", + "verdict_path = download_file_if_absent(\n", + " url=(\n", + " \"https://raw.githubusercontent.com/rasbt/\"\n", + " \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n", + " \"the-verdict.txt\"\n", + " ),\n", + " filename=\"the-verdict.txt\",\n", + " search_dirs=\".\"\n", + ")\n", + "\n", + "with open(verdict_path, \"r\", encoding=\"utf-8\") as f: # added ../01_main-chapter-code/\n", " text = f.read()" ] }, @@ -1168,24 +1238,7 @@ } ], "source": [ - "import os\n", - "import urllib.request\n", - "\n", - "def download_file_if_absent(url, filename, search_dirs):\n", - " for directory in search_dirs:\n", - " file_path = os.path.join(directory, filename)\n", - " if os.path.exists(file_path):\n", - " print(f\"{filename} already exists in {file_path}\")\n", - " return file_path\n", - "\n", - " target_path = os.path.join(search_dirs[0], filename)\n", - " try:\n", - " with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n", - " out_file.write(response.read())\n", - " print(f\"Downloaded {filename} to {target_path}\")\n", - " except Exception as e:\n", - " print(f\"Failed to download {filename}. Error: {e}\")\n", - " return target_path\n", + "# Download files if not already present in this directory\n", "\n", "# Define the directories to search and the files to download\n", "search_directories = [\".\", \"../02_bonus_bytepair-encoder/gpt2_model/\"]\n", @@ -1351,7 +1404,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/ch02/05_bpe-from-scratch/tests/tests.py b/ch02/05_bpe-from-scratch/tests/tests.py new file mode 100644 index 0000000..97ee010 --- /dev/null +++ b/ch02/05_bpe-from-scratch/tests/tests.py @@ -0,0 +1,147 @@ +import os +import sys +import io +import nbformat +import types +import pytest + +import tiktoken + + +def import_definitions_from_notebook(fullname, names): + """Loads function definitions from a Jupyter notebook file into a module.""" + path = os.path.join(os.path.dirname(__file__), "..", fullname + ".ipynb") + path = os.path.normpath(path) + + if not os.path.exists(path): + raise FileNotFoundError(f"Notebook file not found at: {path}") + + with io.open(path, "r", encoding="utf-8") as f: + nb = nbformat.read(f, as_version=4) + + mod = types.ModuleType(fullname) + sys.modules[fullname] = mod + + # Execute all code cells to capture dependencies + for cell in nb.cells: + if cell.cell_type == "code": + exec(cell.source, mod.__dict__) + + # Ensure required names are in module + missing_names = [name for name in names if name not in mod.__dict__] + if missing_names: + raise ImportError(f"Missing definitions in notebook: {missing_names}") + + return mod + + +@pytest.fixture(scope="module") +def imported_module(): + fullname = "bpe-from-scratch" + names = ["BPETokenizerSimple", "download_file_if_absent"] + return import_definitions_from_notebook(fullname, names) + + +@pytest.fixture(scope="module") +def gpt2_files(imported_module): + """Fixture to handle downloading GPT-2 files.""" + download_file_if_absent = getattr(imported_module, "download_file_if_absent", None) + + search_directories = [".", "../02_bonus_bytepair-encoder/gpt2_model/"] + files_to_download = { + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe": "vocab.bpe", + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json": "encoder.json" + } + paths = {filename: download_file_if_absent(url, filename, search_directories) + for url, filename in files_to_download.items()} + + return paths + + +def test_tokenizer_training(imported_module, gpt2_files): + BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None) + download_file_if_absent = getattr(imported_module, "download_file_if_absent", None) + + tokenizer = BPETokenizerSimple() + verdict_path = download_file_if_absent( + url=( + "https://raw.githubusercontent.com/rasbt/" + "LLMs-from-scratch/main/ch02/01_main-chapter-code/" + "the-verdict.txt" + ), + filename="the-verdict.txt", + search_dirs="." + ) + + with open(verdict_path, "r", encoding="utf-8") as f: # added ../01_main-chapter-code/ + text = f.read() + + tokenizer.train(text, vocab_size=1000, allowed_special={"<|endoftext|>"}) + assert len(tokenizer.vocab) == 1000, "Tokenizer vocabulary size mismatch." + assert len(tokenizer.bpe_merges) == 742, "Tokenizer BPE merges count mismatch." + + input_text = "Jack embraced beauty through art and life." + token_ids = tokenizer.encode(input_text) + assert token_ids == [424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 287, 466, 256, 326, 972, 46], "Token IDs do not match expected output." + + assert tokenizer.decode(token_ids) == input_text, "Decoded text does not match the original input." + + tokenizer.save_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt") + tokenizer2 = BPETokenizerSimple() + tokenizer2.load_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt") + assert tokenizer2.decode(token_ids) == input_text, "Decoded text mismatch after reloading tokenizer." + + +def test_gpt2_tokenizer_openai_simple(imported_module, gpt2_files): + BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None) + + tokenizer_gpt2 = BPETokenizerSimple() + tokenizer_gpt2.load_vocab_and_merges_from_openai( + vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"] + ) + + assert len(tokenizer_gpt2.vocab) == 50257, "GPT-2 tokenizer vocabulary size mismatch." + + input_text = "This is some text" + token_ids = tokenizer_gpt2.encode(input_text) + assert token_ids == [1212, 318, 617, 2420], "Tokenized output does not match expected GPT-2 encoding." + + +def test_gpt2_tokenizer_openai_edgecases(imported_module, gpt2_files): + BPETokenizerSimple = getattr(imported_module, "BPETokenizerSimple", None) + + tokenizer_gpt2 = BPETokenizerSimple() + tokenizer_gpt2.load_vocab_and_merges_from_openai( + vocab_path=gpt2_files["encoder.json"], bpe_merges_path=gpt2_files["vocab.bpe"] + ) + tik_tokenizer = tiktoken.get_encoding("gpt2") + + test_cases = [ + ("Hello,", [15496, 11]), + ("Implementations", [3546, 26908, 602]), + ("asdf asdfasdf a!!, @aba 9asdf90asdfk", [292, 7568, 355, 7568, 292, 7568, 257, 3228, 11, 2488, 15498, 860, 292, 7568, 3829, 292, 7568, 74]), + ("Hello, world. Is this-- a test?", [15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]) + ] + + errors = [] + + for input_text, expected_tokens in test_cases: + tik_tokens = tik_tokenizer.encode(input_text) + gpt2_tokens = tokenizer_gpt2.encode(input_text) + + print(f"Text: {input_text}") + print(f"Expected Tokens: {expected_tokens}") + print(f"tiktoken Output: {tik_tokens}") + print(f"BPETokenizerSimple Output: {gpt2_tokens}") + print("-" * 40) + + if tik_tokens != expected_tokens: + errors.append(f"Tiktokenized output does not match expected GPT-2 encoding for '{input_text}'.\n" + f"Expected: {expected_tokens}, Got: {tik_tokens}") + + if gpt2_tokens != expected_tokens: + errors.append(f"Tokenized output does not match expected GPT-2 encoding for '{input_text}'.\n" + f"Expected: {expected_tokens}, Got: {gpt2_tokens}") + + if errors: + pytest.fail("\n".join(errors))