diff --git a/README.md b/README.md index fd6f371..2099715 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,7 @@ Several folders contain optional materials as a bonus for interested readers: - [Gemma 3 From Scratch](ch05/12_gemma3/) - [Olmo 3 From Scratch](ch05/13_olmo3/) - [Tiny Aya From Scratch](ch05/15_tiny-aya/) + - [Qwen3.5 From Scratch](ch05/16_qwen3.5/) - [Chapter 5 with other LLMs as Drop-In Replacement (e.g., Llama 3, Qwen 3)](ch05/14_ch05_with_other_llms/) - **Chapter 6: Finetuning for classification** - [Additional Experiments Finetuning Different Layers and Using Larger Models](ch06/02_bonus_additional-experiments) diff --git a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb index a21a0fa..75d8ed7 100644 --- a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb @@ -82,9 +82,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.35.3\n", - "tokenizers version: 0.22.1\n", - "torch version: 2.8.0\n" + "huggingface_hub version: 1.5.0\n", + "tokenizers version: 0.22.2\n", + "torch version: 2.8.0+cu128\n" ] } ], @@ -659,16 +659,16 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n", - " [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n", - " [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n", + "tensor([[[-0.2334, -0.0134, -0.7031, ..., 0.4316, 0.1177, 1.0703],\n", + " [-0.6641, 0.5352, -0.0752, ..., -0.0698, 0.5430, 0.3203],\n", + " [-0.4785, -0.1748, 0.1074, ..., -0.2354, 0.2354, 0.6289]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, @@ -922,16 +922,7 @@ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" }, - "outputs": [ - { - "name": "stderr", - 
"output_type": "stream", - "text": [ - "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import json\n", "import os\n", @@ -1182,16 +1173,26 @@ "\n", "Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n", "\n", - "I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or content generation. Also, emphasize their adaptability and efficiency. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n", + "I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n", "\n", "\n", - "Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and content creation. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for a wide range of applications." 
+ "Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and answering questions. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for applications in various domains.\n", + "\n", + "Generation speed: 48.46 tokens/sec\n", + "GPU memory used: 1.50 GB\n" ] } ], "source": [ + "import time\n", + "\n", "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", "\n", "for token in generate_text_basic_stream(\n", " model=model,\n", @@ -1199,12 +1200,23 @@ " max_new_tokens=500,\n", " eos_token_id=tokenizer.eos_token_id\n", "):\n", + " generated_tokens += 1\n", " token_id = token.squeeze(0).tolist()\n", " print(\n", " tokenizer.decode(token_id),\n", " end=\"\",\n", " flush=True\n", - " )" + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n" ] }, { diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb index 55d64d2..fdc40fb 100644 --- a/ch05/11_qwen3/standalone-qwen3.ipynb +++ b/ch05/11_qwen3/standalone-qwen3.ipynb @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 21, "id": "7c201adb-747e-437b-9a62-442802941e01", "metadata": {}, "outputs": [], @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 22, "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", 
"metadata": { "colab": { @@ -80,9 +80,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.34.4\n", - "tokenizers version: 0.21.4\n", - "torch version: 2.8.0\n" + "huggingface_hub version: 1.5.0\n", + "tokenizers version: 0.22.2\n", + "torch version: 2.8.0+cu128\n" ] } ], @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 23, "id": "70a90338-624a-4706-aa55-6b4358070194", "metadata": {}, "outputs": [], @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 24, "id": "82076c21-9331-4dcd-b017-42b046cf1a60", "metadata": { "id": "82076c21-9331-4dcd-b017-42b046cf1a60" @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 25, "id": "56715760-37e1-433e-89da-04864c139a9e", "metadata": {}, "outputs": [], @@ -200,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 26, "id": "4b9a346f-5826-4083-9162-abd56afc03f0", "metadata": { "id": "4b9a346f-5826-4083-9162-abd56afc03f0" @@ -252,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 27, "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", "metadata": { "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 28, "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", "metadata": { "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" @@ -367,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 29, "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", "metadata": { "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" @@ -431,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 30, "id": "caa142fa-b375-4e78-b392-2072ced666f3", "metadata": { "id": "caa142fa-b375-4e78-b392-2072ced666f3" @@ -536,7 +536,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 31, "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", 
"metadata": { "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" @@ -549,7 +549,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 32, "id": "eaf86265-4e9d-4024-9ed0-99076944e304", "metadata": {}, "outputs": [ @@ -582,7 +582,7 @@ ")" ] }, - "execution_count": 12, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -601,20 +601,20 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 33, "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n", - " [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n", - " [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n", + "tensor([[[-0.2334, -0.0134, -0.7031, ..., 0.4316, 0.1177, 1.0703],\n", + " [-0.6641, 0.5352, -0.0752, ..., -0.0698, 0.5430, 0.3203],\n", + " [-0.4785, -0.1748, 0.1074, ..., -0.2354, 0.2354, 0.6289]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, - "execution_count": 13, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -625,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 34, "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": { "colab": { @@ -656,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 35, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": { "colab": { @@ -706,7 +706,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 36, "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" @@ -736,7 +736,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 37, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": { "id": "75166128-5899-4995-9b88-9672e135650e" @@ -841,7 +841,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 38, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", 
"metadata": { "colab": { @@ -864,7 +864,22 @@ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b83bbaf857414e8b8842a6af8bfe3071", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 0%| | 0.00/1.50G [00:00user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'" ] }, - "execution_count": 21, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -1069,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 42, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1095,7 +1125,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 43, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1108,16 +1138,26 @@ "\n", "Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n", "\n", - "I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or content generation. Also, emphasize their adaptability and efficiency. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n", + "I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. 
Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n", "\n", "\n", - "Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and content creation. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for applications in various domains." + "Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and answering questions. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for a wide range of applications.\n", + "\n", + "Generation speed: 32.84 tokens/sec\n", + "GPU memory used: 3.03 GB\n" ] } ], "source": [ + "import time\n", + "\n", "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", "\n", "for token in generate_text_basic_stream(\n", " model=model,\n", @@ -1125,12 +1165,23 @@ " max_new_tokens=500,\n", " eos_token_id=tokenizer.eos_token_id\n", "):\n", + " generated_tokens += 1\n", " token_id = token.squeeze(0).tolist()\n", " print(\n", " tokenizer.decode(token_id),\n", " end=\"\",\n", " flush=True\n", - " )" + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: 
{calc_gpu_gb(torch.cuda.max_memory_allocated())}\")" ] }, { diff --git a/ch05/16_qwen3.5/README.md b/ch05/16_qwen3.5/README.md new file mode 100644 index 0000000..0a55245 --- /dev/null +++ b/ch05/16_qwen3.5/README.md @@ -0,0 +1,20 @@ +# Qwen3.5 0.8B From Scratch + +This folder contains a from-scratch style implementation of [Qwen/Qwen3.5-0.8B](https://huggingface.co/Qwen/Qwen3.5-0.8B). + + + +Qwen3.5 is based on the Qwen3-Next architecture, which I described in more detail in section [2. (Linear) Attention Hybrids](https://magazine.sebastianraschka.com/i/177848019/2-linear-attention-hybrids) of my [Beyond Standard LLMs](https://magazine.sebastianraschka.com/p/beyond-standard-llms) article + + + +Note that Qwen3.5 alternates `linear_attention` and `full_attention` layers. +The notebooks keep the full model flow readable while reusing the linear-attention building blocks from the [qwen3_5_transformers.py](qwen3_5_transformers.py), which contains the linear attention code from Hugging Face under an Apache version 2.0 open source license. + +  +## Files + +- [qwen3.5.ipynb](qwen3.5.ipynb): Main Qwen3.5 0.8B notebook implementation. +- [qwen3.5-plus-kv-cache.ipynb](qwen3.5-plus-kv-cache.ipynb): Same model with KV-cache decoding for efficiency. +- [qwen3_5_transformers.py](qwen3_5_transformers.py): Some helper components from Hugging Face Transformers used for Qwen3.5 linear attention. + diff --git a/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb new file mode 100644 index 0000000..929de3d --- /dev/null +++ b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb @@ -0,0 +1,1708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": { + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": { + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" + }, + "source": [ + "# Qwen3.5 From Scratch" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": { + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" + }, + "source": [ + "- This notebook is purposefully minimal and focuses on a readable re-implementation of the Qwen3.5 text stack for the [Qwen/Qwen3.5-0.8B on Hugging Face](https://huggingface.co/Qwen/Qwen3.5-0.8B) checkpoint that maps it onto the scaffold I used for the other from-scratch implementations in this repo\n", + "- Qwen3.5 alternates `linear_attention` and `full_attention` layers\n", + "- Note that this notebook is not 100% standalone & from-scratch as it re-uses some code (i.e., the `Qwen3_5GatedDeltaNet` for the linear attention layers) from the Hugging Face transformers library; the relevant parts are inside the [qwen3_5_transformers.py](qwen3_5_transformers.py) file" + ] + }, + { + "cell_type": "markdown", + "id": "b304d453-f7da-4e17-8330-3a08a67ae3b1", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "id": "1241a20b-d196-4521-9228-d46954d383e4", + "metadata": {}, + "source": [ + "- Qwen3.5 is based on the Qwen3-Next architecture, which I described in more detail in section [2. 
(Linear) Attention Hybrids](https://magazine.sebastianraschka.com/i/177848019/2-linear-attention-hybrids) of my [Beyond Standard LLMs](https://magazine.sebastianraschka.com/p/beyond-standard-llms) article" + ] + }, + { + "cell_type": "markdown", + "id": "21d38944-0c98-40a6-a6f8-c745769b4618", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7c201adb-747e-437b-9a62-442802941e01", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 1.5.0\n", + "tokenizers version: 0.22.2\n", + "torch version: 2.8.0+cu128\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tokenizers\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70a90338-624a-4706-aa55-6b4358070194", + "metadata": {}, + "outputs": [], + "source": [ + "USE_MODEL = \"Qwen3.5-0.8B\"" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": { + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" + }, + "source": [ + " \n", + "# 1. 
Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": { + "id": "82076c21-9331-4dcd-b017-42b046cf1a60" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56715760-37e1-433e-89da-04864c139a9e", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, emb_dim, eps=1e-6):\n", + " super().__init__()\n", + " self.eps = eps\n", + " # Qwen3.5 uses (1 + weight) scaling with zero init\n", + " self.weight = nn.Parameter(torch.zeros(emb_dim))\n", + "\n", + " def _norm(self, x):\n", + " return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n", + "\n", + " def forward(self, x):\n", + " x_norm = self._norm(x.float())\n", + " x_norm = x_norm * (1.0 + self.weight.float())\n", + " return x_norm.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": { + "id": "4b9a346f-5826-4083-9162-abd56afc03f0" + }, + "outputs": [], + "source": [ + "def compute_rope_params(\n", + " head_dim,\n", + " theta_base=10_000,\n", + " context_length=4096,\n", + " partial_rotary_factor=1.0,\n", + " dtype=torch.float32,\n", + "):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be 
even\"\n", + "\n", + " rotary_dim = int(head_dim * partial_rotary_factor)\n", + " rotary_dim = max(2, rotary_dim - (rotary_dim % 2))\n", + "\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, rotary_dim, 2, dtype=dtype)[: (rotary_dim // 2)].float() / rotary_dim\n", + " )\n", + " )\n", + "\n", + " positions = torch.arange(context_length, dtype=dtype)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", + " angles = torch.cat([angles, angles], dim=1)\n", + "\n", + " cos = torch.cos(angles)\n", + " sin = torch.sin(angles)\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def apply_rope(x, cos, sin, offset=0):\n", + " _, _, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " rot_dim = cos.shape[-1]\n", + " if rot_dim > head_dim:\n", + " raise ValueError(f\"RoPE dim {rot_dim} cannot exceed head_dim {head_dim}.\")\n", + "\n", + " x_rot = x[..., :rot_dim]\n", + " x_pass = x[..., rot_dim:]\n", + "\n", + " x1 = x_rot[..., : rot_dim // 2]\n", + " x2 = x_rot[..., rot_dim // 2 :]\n", + "\n", + " cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n", + " sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x_rot * cos) + (rotated * sin)\n", + "\n", + " x_out = torch.cat([x_rotated, x_pass], dim=-1)\n", + " return x_out.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": { + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" + }, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(\n", + " self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n", + " ):\n", + " super().__init__()\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.num_heads = num_heads\n", + " 
self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + "\n", + " if head_dim is None:\n", + " assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n", + " head_dim = d_in // num_heads\n", + "\n", + " self.head_dim = head_dim\n", + " self.d_out = num_heads * head_dim\n", + "\n", + " # Qwen3.5 full-attention uses a gated Q projection (2x output dim)\n", + " self.W_query = nn.Linear(d_in, self.d_out * 2, bias=False, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + "\n", + " self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n", + "\n", + " if qk_norm:\n", + " self.q_norm = RMSNorm(head_dim, eps=1e-6)\n", + " self.k_norm = RMSNorm(head_dim, eps=1e-6)\n", + " else:\n", + " self.q_norm = self.k_norm = None\n", + "\n", + " def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n", + " b, num_tokens, _ = x.shape\n", + "\n", + " q_and_gate = self.W_query(x)\n", + " q_and_gate = q_and_gate.view(b, num_tokens, self.num_heads, self.head_dim * 2)\n", + " queries, gate = torch.chunk(q_and_gate, 2, dim=-1)\n", + " gate = gate.reshape(b, num_tokens, self.d_out)\n", + "\n", + " keys = self.W_key(x)\n", + " values = self.W_value(x)\n", + "\n", + " queries = queries.transpose(1, 2)\n", + " keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + " values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + "\n", + " if self.q_norm:\n", + " queries = self.q_norm(queries)\n", + " if self.k_norm:\n", + " keys_new = self.k_norm(keys_new)\n", + "\n", + " prev_len = 0\n", + " if cache is not None:\n", + " prev_k, prev_v = cache\n", + " if prev_k is not None:\n", + " prev_len = prev_k.size(2)\n", + " keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)\n", + " 
values_cat_raw = torch.cat([prev_v, values_new], dim=2)\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + "\n", + " queries = apply_rope(queries, cos, sin, offset=start_pos)\n", + " keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)\n", + "\n", + " keys = keys.repeat_interleave(self.group_size, dim=1)\n", + " values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n", + "\n", + " if cache is not None and cache[0] is not None:\n", + " next_cache = (\n", + " torch.cat([cache[0], keys_new], dim=2),\n", + " torch.cat([cache[1], values_new], dim=2),\n", + " )\n", + " else:\n", + " next_cache = (keys_new, values_new)\n", + "\n", + " attn_scores = queries @ keys.transpose(2, 3)\n", + " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", + " attn_weights = torch.softmax(\n", + " attn_scores * (self.head_dim ** -0.5),\n", + " dim=-1,\n", + " dtype=torch.float32,\n", + " ).to(queries.dtype)\n", + "\n", + " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n", + "\n", + " # Qwen3.5 full-attention uses a gated Q projection\n", + " context = context * torch.sigmoid(gate)\n", + " out = self.out_proj(context)\n", + " return out, next_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": { + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" + }, + "outputs": [], + "source": [ + "from qwen3_5_transformers import (\n", + " Qwen3_5GatedDeltaNet,\n", + ")\n", + "\n", + "# Just a mapping for the different naming convention in Hugging Face transformers\n", + "class _Qwen3_5ConfigAdapter:\n", + " def __init__(self, cfg):\n", + " self.hidden_size = cfg[\"emb_dim\"]\n", + " self.linear_num_value_heads = cfg[\"linear_num_value_heads\"]\n", + " self.linear_num_key_heads = cfg[\"linear_num_key_heads\"]\n", + " self.linear_key_head_dim 
= cfg[\"linear_key_head_dim\"]\n", + " self.linear_value_head_dim = cfg[\"linear_value_head_dim\"]\n", + " self.linear_conv_kernel_dim = cfg[\"linear_conv_kernel_dim\"]\n", + " self.hidden_act = \"silu\"\n", + " self.rms_norm_eps = cfg.get(\"rms_norm_eps\", 1e-6)\n", + " self.dtype = cfg.get(\"dtype\", None)\n", + "\n", + "\n", + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg, layer_type, layer_idx):\n", + " super().__init__()\n", + " self.layer_type = layer_type\n", + "\n", + " if layer_type == \"full_attention\":\n", + " self.token_mixer = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " head_dim=cfg[\"head_dim\"],\n", + " num_kv_groups=cfg[\"n_kv_groups\"],\n", + " qk_norm=cfg[\"qk_norm\"],\n", + " dtype=cfg[\"dtype\"],\n", + " )\n", + " elif layer_type == \"linear_attention\":\n", + " self.token_mixer = Qwen3_5GatedDeltaNet(_Qwen3_5ConfigAdapter(cfg), layer_idx)\n", + " else:\n", + " raise ValueError(f\"Unsupported layer type: {layer_type}\")\n", + "\n", + " self.ff = FeedForward(cfg)\n", + " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + "\n", + " def forward(self, x, mask, cos, sin, start_pos=0, cache=None, linear_cache=None, cache_position=None):\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + "\n", + " if self.layer_type == \"full_attention\":\n", + " x, next_cache = self.token_mixer(\n", + " x,\n", + " mask,\n", + " cos,\n", + " sin,\n", + " start_pos=start_pos,\n", + " cache=cache,\n", + " )\n", + " else:\n", + " x = self.token_mixer(\n", + " x,\n", + " cache_params=linear_cache,\n", + " cache_position=cache_position,\n", + " )\n", + " next_cache = None\n", + "\n", + " x = x + shortcut\n", + "\n", + " shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x)\n", + " x = x + shortcut\n", + "\n", + " return x, next_cache" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", + "metadata": { + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" + }, + "outputs": [], + "source": [ + "class Qwen3_5Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + "\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + "\n", + " layer_types = cfg.get(\"layer_types\", [\"full_attention\"] * cfg[\"n_layers\"])\n", + " if len(layer_types) != cfg[\"n_layers\"]:\n", + " raise ValueError(\"len(layer_types) must equal n_layers\")\n", + "\n", + " self.trf_blocks = nn.ModuleList(\n", + " [TransformerBlock(cfg, layer_type, idx) for idx, layer_type in enumerate(layer_types)]\n", + " )\n", + "\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + "\n", + " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"] if cfg[\"head_dim\"] is None else cfg[\"head_dim\"]\n", + " cos, sin = compute_rope_params(\n", + " head_dim=head_dim,\n", + " theta_base=cfg[\"rope_base\"],\n", + " context_length=cfg[\"context_length\"],\n", + " partial_rotary_factor=cfg.get(\"partial_rotary_factor\", 1.0),\n", + " dtype=torch.float32,\n", + " )\n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + " self.cfg = cfg\n", + " self.current_pos = 0\n", + "\n", + " def create_mask(self, cur_len, device, pos_start=0, pos_end=None):\n", + " if pos_end is None:\n", + " pos_end = cur_len\n", + "\n", + " ones = torch.ones((pos_end, pos_end), device=device, dtype=torch.bool)\n", + " mask_full = torch.triu(ones, diagonal=1)\n", + " row_slice = slice(pos_start, pos_end)\n", + " mask = mask_full[row_slice, :pos_end][None, None, :, :]\n", + " return mask\n", + "\n", + " def forward(self, in_idx, cache=None):\n", + " x = self.tok_emb(in_idx)\n", + "\n", + 
" num_tokens = x.shape[1]\n", + " if cache is not None:\n", + " pos_start = self.current_pos\n", + " pos_end = pos_start + num_tokens\n", + " self.current_pos = pos_end\n", + " mask = self.create_mask(\n", + " cur_len=num_tokens,\n", + " device=x.device,\n", + " pos_start=pos_start,\n", + " pos_end=pos_end,\n", + " )\n", + " cache_position = torch.arange(pos_start, pos_end, device=x.device, dtype=torch.long)\n", + " else:\n", + " pos_start = 0\n", + " mask = self.create_mask(\n", + " cur_len=num_tokens,\n", + " device=x.device,\n", + " pos_start=0,\n", + " pos_end=num_tokens,\n", + " )\n", + " cache_position = None\n", + "\n", + " for i, block in enumerate(self.trf_blocks):\n", + " blk_cache = cache.get(i) if cache is not None else None\n", + " x, new_blk_cache = block(\n", + " x,\n", + " mask=mask,\n", + " cos=self.cos,\n", + " sin=self.sin,\n", + " start_pos=pos_start,\n", + " cache=blk_cache,\n", + " linear_cache=cache.linear_cache if cache is not None else None,\n", + " cache_position=cache_position,\n", + " )\n", + " if cache is not None and new_blk_cache is not None:\n", + " cache.update(i, new_blk_cache)\n", + "\n", + " if cache is not None:\n", + " cache.linear_cache.has_previous_state = True\n", + "\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", + " return logits\n", + "\n", + " def reset_kv_cache(self):\n", + " self.current_pos = 0\n", + "\n", + "\n", + "class Qwen3_5LinearAttentionCache:\n", + " def __init__(self, n_layers):\n", + " self.conv_states = [None] * n_layers\n", + " self.recurrent_states = [None] * n_layers\n", + " self.has_previous_state = False\n", + "\n", + " def reset(self):\n", + " for i in range(len(self.conv_states)):\n", + " self.conv_states[i] = None\n", + " self.recurrent_states[i] = None\n", + " self.has_previous_state = False\n", + "\n", + "\n", + "class KVCache:\n", + " def __init__(self, n_layers):\n", + " self.cache = [None] * n_layers\n", + " self.linear_cache = 
Qwen3_5LinearAttentionCache(n_layers)\n", + "\n", + " def get(self, layer_idx):\n", + " return self.cache[layer_idx]\n", + "\n", + " def update(self, layer_idx, value):\n", + " self.cache[layer_idx] = value\n", + "\n", + " def get_all(self):\n", + " return self.cache\n", + "\n", + " def reset(self):\n", + " for i in range(len(self.cache)):\n", + " self.cache[i] = None\n", + " self.linear_cache.reset()" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": { + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" + }, + "source": [ + " \n", + "# 2. Initialize model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": { + "id": "caa142fa-b375-4e78-b392-2072ced666f3" + }, + "outputs": [], + "source": [ + "# Qwen3.5-0.8B text configuration\n", + "QWEN3_5_CONFIG = {\n", + " \"vocab_size\": 248_320,\n", + " \"context_length\": 262_144,\n", + " \"emb_dim\": 1_024,\n", + " \"n_heads\": 8,\n", + " \"n_layers\": 24,\n", + " \"hidden_dim\": 3_584,\n", + " \"head_dim\": 256,\n", + " \"qk_norm\": True,\n", + " \"n_kv_groups\": 2,\n", + " \"rope_base\": 10_000_000.0,\n", + " \"partial_rotary_factor\": 0.25,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"linear_conv_kernel_dim\": 4,\n", + " \"linear_key_head_dim\": 128,\n", + " \"linear_value_head_dim\": 128,\n", + " \"linear_num_key_heads\": 16,\n", + " \"linear_num_value_heads\": 16,\n", + " \"dtype\": torch.bfloat16,\n", + " \"layer_types\": [\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " 
\"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " ],\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": { + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d\n" + ] + } + ], + "source": [ + "torch.manual_seed(123)\n", + "model = Qwen3_5Model(QWEN3_5_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eaf86265-4e9d-4024-9ed0-99076944e304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Qwen3_5Model(\n", + " (tok_emb): Embedding(248320, 1024)\n", + " (trf_blocks): ModuleList(\n", + " (0-2): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (3): TransformerBlock(\n", + " (token_mixer): 
GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (4-6): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (7): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, 
bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (8-10): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (11): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, 
bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (12-14): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (15): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (16-18): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): 
Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (19): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (20-22): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, 
out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (23): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " )\n", + " (final_norm): RMSNorm()\n", + " (out_head): Linear(in_features=1024, out_features=248320, bias=False)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b", + "metadata": {}, + "source": [ + "- A quick check that the forward pass works before continuing:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.6719, -0.0347, -0.5938, ..., 0.5469, 0.1660, -0.8945],\n", + " [ 0.0391, -0.1226, -0.8789, ..., -0.6523, -0.8281, -0.0889],\n", + " [ 0.1992, -0.7930, -0.3359, ..., -0.6602, 0.0515, 
# Quick check that the forward pass works before continuing
model(torch.tensor([1, 2, 3]).unsqueeze(0))

total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")


def calc_model_memory_size(model, input_dtype=torch.float32):
    """Estimate the model's memory footprint (params + grads + buffers) in GB.

    Assumes parameters, gradients, and buffers are all stored in `input_dtype`.
    """
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        param_size = param.numel()
        total_params += param_size
        # Gradients are only stored for trainable parameters
        if param.requires_grad:
            total_grads += param_size

    # Buffers are non-parameter tensors that also occupy memory (e.g. RoPE tables)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size
    return total_memory_bytes / (1024**3)


print(f"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)


def load_weights_into_qwen3_5(model, param_config, params):
    """Copy Hugging Face checkpoint tensors from `params` into `model` in place.

    Supports both flat ("model.") and multimodal ("model.language_model.")
    key prefixes; raises KeyError if neither is present and ValueError on any
    shape mismatch.
    """

    def assign(left, right, tensor_name="unknown"):
        # Fail loudly on mismatched shapes instead of silently broadcasting.
        if left.shape != right.shape:
            raise ValueError(
                f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}"
            )
        with torch.no_grad():
            if isinstance(right, torch.Tensor):
                left.copy_(right)
            else:
                left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))
        return left

    if "model.embed_tokens.weight" in params:
        model_prefix = "model"
    elif "model.language_model.embed_tokens.weight" in params:
        model_prefix = "model.language_model"
    else:
        raise KeyError("Could not find embed token weights in checkpoint.")

    def pkey(suffix):
        return f"{model_prefix}.{suffix}"

    key = pkey("embed_tokens.weight")
    model.tok_emb.weight = assign(model.tok_emb.weight, params[key], key)

    n_layers = param_config["n_layers"]
    layer_types = param_config.get("layer_types", ["full_attention"] * n_layers)

    for l in range(n_layers):
        block = model.trf_blocks[l]
        layer_type = layer_types[l]

        if layer_type == "full_attention":
            att = block.token_mixer
            for module, name in (
                (att.W_query, "q_proj"),
                (att.W_key, "k_proj"),
                (att.W_value, "v_proj"),
                (att.out_proj, "o_proj"),
            ):
                key = pkey(f"layers.{l}.self_attn.{name}.weight")
                module.weight = assign(module.weight, params[key], key)
            # q/k norms are optional depending on the config
            for attr in ("q_norm", "k_norm"):
                norm = getattr(att, attr, None)
                if norm is not None:
                    key = pkey(f"layers.{l}.self_attn.{attr}.weight")
                    norm.weight = assign(norm.weight, params[key], key)

        elif layer_type == "linear_attention":
            lat = block.token_mixer
            # dt_bias and A_log are bare parameters (no .weight attribute)
            for attr in ("dt_bias", "A_log"):
                key = pkey(f"layers.{l}.linear_attn.{attr}")
                setattr(lat, attr, assign(getattr(lat, attr), params[key], key))
            for attr in ("conv1d", "norm", "out_proj", "in_proj_qkv", "in_proj_z", "in_proj_b", "in_proj_a"):
                key = pkey(f"layers.{l}.linear_attn.{attr}.weight")
                module = getattr(lat, attr)
                module.weight = assign(module.weight, params[key], key)

        else:
            raise ValueError(f"Unsupported layer type: {layer_type}")

        key = pkey(f"layers.{l}.input_layernorm.weight")
        block.norm1.weight = assign(block.norm1.weight, params[key], key)

        for module, name in (
            (block.ff.fc1, "gate_proj"),
            (block.ff.fc2, "up_proj"),
            (block.ff.fc3, "down_proj"),
        ):
            key = pkey(f"layers.{l}.mlp.{name}.weight")
            module.weight = assign(module.weight, params[key], key)

        key = pkey(f"layers.{l}.post_attention_layernorm.weight")
        block.norm2.weight = assign(block.norm2.weight, params[key], key)

    key = pkey("norm.weight")
    model.final_norm.weight = assign(model.final_norm.weight, params[key], key)

    if "lm_head.weight" in params:
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    elif pkey("lm_head.weight") in params:
        model.out_head.weight = assign(model.out_head.weight, params[pkey("lm_head.weight")], pkey("lm_head.weight"))
    else:
        # No separate output head in the checkpoint: reuse the embedding matrix.
        model.out_head.weight = model.tok_emb.weight
        print("Model uses weight tying.")


# NOTE(review): the original notebook cell that downloads the safetensors
# shards and calls load_weights_into_qwen3_5(model, QWEN3_5_CONFIG, ...) was
# garbled in this copy (an HTML-tag stripper ate the span) and is not
# reproduced here — restore it from the upstream notebook.


class Qwen3_5Tokenizer:
    """Minimal Qwen3.5 tokenizer wrapper around a Hugging Face tokenizer.json.

    Handles the special chat tokens itself so that templated prompts
    round-trip through encode/decode exactly.
    """

    # NOTE(review): the class header and the "<|endoftext|>" entry were
    # reconstructed — the original span was corrupted in this copy; the
    # "<think>"/"</think>" entries are restored from the _wrap_chat comments.
    _SPECIALS = [
        "<|endoftext|>",
        "<|im_start|>", "<|im_end|>",
        "<|object_ref_start|>", "<|object_ref_end|>",
        "<|box_start|>", "<|box_end|>",
        "<|quad_start|>", "<|quad_end|>",
        "<|vision_start|>", "<|vision_end|>",
        "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
        "<think>", "</think>",
    ]
    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

    def __init__(
        self,
        tokenizer_file_path="tokenizer.json",
        repo_id=None,
        apply_chat_template=True,
        add_generation_prompt=False,
        add_thinking=False,
    ):
        self.apply_chat_template = apply_chat_template
        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tok_file = Path(tokenizer_file_path)
        self._tok = Tokenizer.from_file(str(tok_file))
        # Map only the specials that actually exist in this vocabulary.
        self._special_to_id = {}
        for t in self._SPECIALS:
            tid = self._tok.token_to_id(t)
            if tid is not None:
                self._special_to_id[t] = tid

        self.pad_token_id = self._special_to_id["<|endoftext|>"]
        self.eos_token_id = self.pad_token_id

        # Instruct checkpoints end turns with <|im_end|>; base models use <|endoftext|>.
        if repo_id and "Base" not in repo_id:
            eos_token = "<|im_end|>"
        else:
            eos_token = "<|endoftext|>"
        if eos_token in self._special_to_id:
            self.eos_token_id = self._special_to_id[eos_token]

    def encode(self, text, chat_wrapped=None):
        if chat_wrapped is None:
            chat_wrapped = self.apply_chat_template

        # A lone special token encodes directly to its id.
        stripped = text.strip()
        if stripped in self._special_to_id and "\n" not in stripped:
            return [self._special_to_id[stripped]]

        if chat_wrapped:
            text = self._wrap_chat(text)

        # Split on special tokens so they map to their reserved ids.
        ids = []
        for part in filter(None, self._SPLIT_RE.split(text)):
            if part in self._special_to_id:
                ids.append(self._special_to_id[part])
            else:
                ids.extend(self._tok.encode(part).ids)
        return ids

    def decode(self, ids):
        return self._tok.decode(ids, skip_special_tokens=False)

    def _wrap_chat(self, user_msg):
        # Mirrors Qwen3.5 chat_template behavior:
        #   add_generation_prompt + thinking    => "<think>\n"
        #   add_generation_prompt + no thinking => empty think scaffold
        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        if self.add_generation_prompt:
            s += "<|im_start|>assistant\n"
            if self.add_thinking:
                s += "<think>\n"
            else:
                s += "<think>\n\n</think>\n\n"
        return s
20, + "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_file_path = \"Qwen3.5-0.8B/tokenizer.json\"\n", + "\n", + "hf_hub_download(\n", + " repo_id=repo_id,\n", + " filename=\"tokenizer.json\",\n", + " local_dir=local_dir,\n", + ")\n", + "\n", + "tokenizer = Qwen3_5Tokenizer(\n", + " tokenizer_file_path=tokenizer_file_path,\n", + " repo_id=repo_id,\n", + " apply_chat_template=True,\n", + " add_generation_prompt=True,\n", + " add_thinking=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1946b534-e3af-431a-a222-391a60bfa892", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n\\n'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"Give me a short introduction to large language models.\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "text = tokenizer.decode(input_token_ids)\n", + "text" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": { + "id": "57d07df1-4401-4792-b549-7c4cc5632323" + }, + "source": [ + " \n", + "# 4. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": { + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" + }, + "outputs": [], + "source": [ + "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None):\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for _ in range(max_new_tokens):\n", + " out = model(token_ids)[:, -1]\n", + " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + "\n", + " if (eos_token_id is not None\n", + " and torch.all(next_token == eos_token_id)):\n", + " break\n", + "\n", + " yield next_token\n", + " \n", + " token_ids = torch.cat([token_ids, next_token], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", + "metadata": { + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thinking Process:\n", + "\n", + "1. **Analyze the Request:**\n", + " * **Topic:** Large Language Models (LLMs).\n", + " * **Task:** Give a short introduction.\n", + " * **Constraint:** \"Short\" (implies concise, clear, and impactful).\n", + "\n", + "2. **Identify Key Concepts:**\n", + " * What are they? (AI models trained on massive datasets).\n", + " * What do they do? (Generate text, code, etc.).\n", + " * How do they work? (Neural networks, transformers, training).\n", + " * Why are they important? (Efficiency, context, creativity).\n", + " * *Self-Correction/Refinement:* Keep it simple but accurate. Avoid overly technical jargon unless necessary, but \"transformers\" is a key term.\n", + "\n", + "3. **Drafting - Attempt 1 (Mental Outline):**\n", + " LLMs are big AI models. They are trained on huge amounts of data. They can understand and generate text. They are like a supercomputer for language. They are used in chatbots and coding.\n", + "\n", + "4. 
**Drafting - Attempt 2 (Adding Detail & Flow):**\n", + " Large Language Models (LLMs) are a type of artificial intelligence. They are trained on massive datasets of text. They use neural networks to understand and generate human-like text. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n", + "\n", + "5. **Drafting - Attempt 3 (Polishing for \"Short Introduction\"):**\n", + " Large Language Models (LLMs) are a type of artificial intelligence that can understand and generate human-like text. They are trained on massive datasets of text. They use neural networks to process information and create content. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n", + "\n", + "6. **Refining for Clarity and Impact:**\n", + " * Make it punchy.\n", + " * Highlight the \"transformers\" or \"neural networks\" aspect if needed, but keep it simple.\n", + " * Mention the \"big data\" aspect.\n", + "\n", + "7. 
**Final Polish (incorporating into the final output):**\n", + " * Start with a definition.\n", + " * Mention the core technology (neural networks).\n", + " * Mention the output\n", + "\n", + "Generation speed: 8.25 tokens/sec\n", + "GPU memory used: 2.54 GB\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "prompt = \"Give me a short introduction to large language models.\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " generated_tokens += 1\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b0ef78d8-e512-47c2-aaab-d236a6e7cad3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Here's" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " a thinking process that leads to the solution:\n", + "\n", + "1. 
**Analyze the Request:**\n", + " * **Scenario:** A shop applies two discounts and a tax.\n", + " * **Discount:** 20% off the original price.\n", + " * **Tax:** 10% added on top of the discounted price.\n", + " * **Question:** Is the final price higher or lower than the original? By how much?\n", + "\n", + "2. **Define Variables:**\n", + " * Let $P$ be the original price.\n", + "\n", + "3. **Step-by-Step Calculation:**\n", + "\n", + " * *Step 1: Apply the 20% discount.*\n", + " * Discount amount = $0.20 \\times P$\n", + " * Final price after discount = $P - 0.20P$\n", + " * Final price after discount = $0.80P$\n", + "\n", + " * *Step 2: Apply the 10% tax.*\n", + " * Tax amount = $0.10 \\times (\\text{Final price after discount})$\n", + " * Tax amount = $0.10 \\times (0.80P)$\n", + " * Tax amount = $0.08P$\n", + " * Final price after tax = Final price after discount + Tax amount\n", + " * Final price after tax = $0.80P + 0.08P$\n", + " * Final price after tax = $0.88P$\n", + "\n", + " * *Step 3: Compare Final Price to Original Price.*\n", + " * Original Price = $P$\n", + " * Final Price = $0.88P$\n", + " * Since $0.88 < 1$, the final price is lower.\n", + "\n", + " * *Step 4: Calculate the difference.*\n", + " * Difference = Final Price - Original Price\n", + " * Difference = $0.88P - P$\n", + " * Difference = $-0.12P$\n", + " * The difference is $0.12P$ (or 12% of the original price).\n", + "\n", + "4. **Verification:**\n", + " * Let's pick a specific number to make sure.\n", + " * Let $P = 100$.\n", + " * \n", + "\n", + "Generation speed: 9.00 tokens/sec\n", + "GPU memory used: 2.56 GB\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "prompt = \"A shop gives a 20% discount, then adds 10% tax. Is the final price higher or lower than the original? 
By how much?\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " generated_tokens += 1\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": { + "id": "549324d6-5c71-4147-ae21-2e67675faa3d" + }, + "source": [ + " \n", + "# What's next?" 
+ ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": { + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" + }, + "source": [ + "- Check out the [README.md](../11_qwen3/README.md), to use this model via the `llms_from_scratch` package\n", + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/16_qwen3.5/qwen3.5.ipynb b/ch05/16_qwen3.5/qwen3.5.ipynb new file mode 100644 index 0000000..83777c3 --- /dev/null +++ b/ch05/16_qwen3.5/qwen3.5.ipynb @@ -0,0 +1,1602 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": { + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": { + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" + }, + "source": [ + "# Qwen3.5 From Scratch" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": { + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" + }, + "source": [ + "- This notebook is purposefully minimal and focuses on a readable re-implementation of the Qwen3.5 text stack for the [Qwen/Qwen3.5-0.8B on Hugging Face](https://huggingface.co/Qwen/Qwen3.5-0.8B) checkpoint that maps it onto the scaffold I used for the other from-scratch implementations in this repo\n", + "- Qwen3.5 alternates `linear_attention` and `full_attention` layers\n", + "- Note that this notebook is not 100% standalone & from-scratch as it re-uses some code (i.e., the `Qwen3_5GatedDeltaNet` for the linear attention layers) from the Hugging Face transformers library; the relevant parts are inside the [qwen3_5_transformers.py](qwen3_5_transformers.py) file" + ] + }, + { + "cell_type": "markdown", + "id": "b304d453-f7da-4e17-8330-3a08a67ae3b1", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "id": "1241a20b-d196-4521-9228-d46954d383e4", + "metadata": {}, + "source": [ + "- Qwen3.5 is based on the Qwen3-Next architecture, which I described in more detail in section [2. 
(Linear) Attention Hybrids](https://magazine.sebastianraschka.com/i/177848019/2-linear-attention-hybrids) of my [Beyond Standard LLMs](https://magazine.sebastianraschka.com/p/beyond-standard-llms) article" + ] + }, + { + "cell_type": "markdown", + "id": "402a446f-4efe-41f5-acc0-4f8455846aa5", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7c201adb-747e-437b-9a62-442802941e01", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 1.5.0\n", + "tokenizers version: 0.22.2\n", + "torch version: 2.8.0+cu128\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tokenizers\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70a90338-624a-4706-aa55-6b4358070194", + "metadata": {}, + "outputs": [], + "source": [ + "USE_MODEL = \"Qwen3.5-0.8B\"" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": { + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" + }, + "source": [ + " \n", + "# 1. 
Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": { + "id": "82076c21-9331-4dcd-b017-42b046cf1a60" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56715760-37e1-433e-89da-04864c139a9e", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, emb_dim, eps=1e-6):\n", + " super().__init__()\n", + " self.eps = eps\n", + " # Qwen3.5 uses (1 + weight) scaling with zero init\n", + " self.weight = nn.Parameter(torch.zeros(emb_dim))\n", + "\n", + " def _norm(self, x):\n", + " return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n", + "\n", + " def forward(self, x):\n", + " x_norm = self._norm(x.float())\n", + " x_norm = x_norm * (1.0 + self.weight.float())\n", + " return x_norm.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": { + "id": "4b9a346f-5826-4083-9162-abd56afc03f0" + }, + "outputs": [], + "source": [ + "def compute_rope_params(\n", + " head_dim,\n", + " theta_base=10_000,\n", + " context_length=4096,\n", + " partial_rotary_factor=1.0,\n", + " dtype=torch.float32,\n", + "):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be 
even\"\n", + "\n", + " rotary_dim = int(head_dim * partial_rotary_factor)\n", + " rotary_dim = max(2, rotary_dim - (rotary_dim % 2))\n", + "\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, rotary_dim, 2, dtype=dtype)[: (rotary_dim // 2)].float() / rotary_dim\n", + " )\n", + " )\n", + "\n", + " positions = torch.arange(context_length, dtype=dtype)\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", + " angles = torch.cat([angles, angles], dim=1)\n", + "\n", + " cos = torch.cos(angles)\n", + " sin = torch.sin(angles)\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def apply_rope(x, cos, sin):\n", + " _, _, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " rot_dim = cos.shape[-1]\n", + " if rot_dim > head_dim:\n", + " raise ValueError(f\"RoPE dim {rot_dim} cannot exceed head_dim {head_dim}.\")\n", + "\n", + " x_rot = x[..., :rot_dim]\n", + " x_pass = x[..., rot_dim:]\n", + "\n", + " x1 = x_rot[..., : rot_dim // 2]\n", + " x2 = x_rot[..., rot_dim // 2 :]\n", + "\n", + " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x_rot * cos) + (rotated * sin)\n", + "\n", + " x_out = torch.cat([x_rotated, x_pass], dim=-1)\n", + " return x_out.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": { + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" + }, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(\n", + " self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n", + " ):\n", + " super().__init__()\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.num_kv_groups = num_kv_groups\n", + " 
self.group_size = num_heads // num_kv_groups\n", + "\n", + " if head_dim is None:\n", + " assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n", + " head_dim = d_in // num_heads\n", + "\n", + " self.head_dim = head_dim\n", + " self.d_out = num_heads * head_dim\n", + "\n", + " # Qwen3.5 full-attention uses a gated Q projection (2x output dim)\n", + " self.W_query = nn.Linear(d_in, self.d_out * 2, bias=False, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + "\n", + " self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n", + "\n", + " if qk_norm:\n", + " self.q_norm = RMSNorm(head_dim, eps=1e-6)\n", + " self.k_norm = RMSNorm(head_dim, eps=1e-6)\n", + " else:\n", + " self.q_norm = self.k_norm = None\n", + "\n", + " def forward(self, x, mask, cos, sin):\n", + " b, num_tokens, _ = x.shape\n", + "\n", + " q_and_gate = self.W_query(x)\n", + " q_and_gate = q_and_gate.view(b, num_tokens, self.num_heads, self.head_dim * 2)\n", + " queries, gate = torch.chunk(q_and_gate, 2, dim=-1)\n", + " gate = gate.reshape(b, num_tokens, self.d_out)\n", + "\n", + " keys = self.W_key(x)\n", + " values = self.W_value(x)\n", + "\n", + " queries = queries.transpose(1, 2)\n", + " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + "\n", + " if self.q_norm:\n", + " queries = self.q_norm(queries)\n", + " if self.k_norm:\n", + " keys = self.k_norm(keys)\n", + "\n", + " queries = apply_rope(queries, cos, sin)\n", + " keys = apply_rope(keys, cos, sin)\n", + "\n", + " keys = keys.repeat_interleave(self.group_size, dim=1)\n", + " values = values.repeat_interleave(self.group_size, dim=1)\n", + "\n", + " attn_scores = queries @ keys.transpose(2, 3)\n", + " 
attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", + " attn_weights = torch.softmax(\n", + " attn_scores * (self.head_dim ** -0.5),\n", + " dim=-1,\n", + " dtype=torch.float32,\n", + " ).to(queries.dtype)\n", + "\n", + " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n", + "\n", + " # Qwen3.5 full-attention uses a gated Q projection\n", + " context = context * torch.sigmoid(gate)\n", + " return self.out_proj(context)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": { + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" + }, + "outputs": [], + "source": [ + "from qwen3_5_transformers import (\n", + " Qwen3_5GatedDeltaNet,\n", + ")\n", + "\n", + "# Just a mapping for the different naming convention in Hugging Face transformers\n", + "class _Qwen3_5ConfigAdapter:\n", + " def __init__(self, cfg):\n", + " self.hidden_size = cfg[\"emb_dim\"]\n", + " self.linear_num_value_heads = cfg[\"linear_num_value_heads\"]\n", + " self.linear_num_key_heads = cfg[\"linear_num_key_heads\"]\n", + " self.linear_key_head_dim = cfg[\"linear_key_head_dim\"]\n", + " self.linear_value_head_dim = cfg[\"linear_value_head_dim\"]\n", + " self.linear_conv_kernel_dim = cfg[\"linear_conv_kernel_dim\"]\n", + " self.hidden_act = \"silu\"\n", + " self.rms_norm_eps = cfg.get(\"rms_norm_eps\", 1e-6)\n", + " self.dtype = cfg.get(\"dtype\", None)\n", + "\n", + "\n", + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg, layer_type, layer_idx):\n", + " super().__init__()\n", + " self.layer_type = layer_type\n", + "\n", + " if layer_type == \"full_attention\":\n", + " self.token_mixer = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " head_dim=cfg[\"head_dim\"],\n", + " num_kv_groups=cfg[\"n_kv_groups\"],\n", + " qk_norm=cfg[\"qk_norm\"],\n", + " dtype=cfg[\"dtype\"],\n", + " )\n", + " elif layer_type == \"linear_attention\":\n", + 
" self.token_mixer = Qwen3_5GatedDeltaNet(_Qwen3_5ConfigAdapter(cfg), layer_idx)\n", + " else:\n", + " raise ValueError(f\"Unsupported layer type: {layer_type}\")\n", + "\n", + " self.ff = FeedForward(cfg)\n", + " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + "\n", + " def forward(self, x, mask, cos, sin):\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + "\n", + " if self.layer_type == \"full_attention\":\n", + " x = self.token_mixer(x, mask, cos, sin)\n", + " else:\n", + " x = self.token_mixer(x)\n", + "\n", + " x = x + shortcut\n", + "\n", + " shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x)\n", + " x = x + shortcut\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", + "metadata": { + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" + }, + "outputs": [], + "source": [ + "class Qwen3_5Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + "\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + "\n", + " layer_types = cfg.get(\"layer_types\", [\"full_attention\"] * cfg[\"n_layers\"])\n", + " if len(layer_types) != cfg[\"n_layers\"]:\n", + " raise ValueError(\"len(layer_types) must equal n_layers\")\n", + "\n", + " self.trf_blocks = nn.ModuleList(\n", + " [TransformerBlock(cfg, layer_type, idx) for idx, layer_type in enumerate(layer_types)]\n", + " )\n", + "\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + "\n", + " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"] if cfg[\"head_dim\"] is None else cfg[\"head_dim\"]\n", + " cos, sin = compute_rope_params(\n", + " head_dim=head_dim,\n", + " theta_base=cfg[\"rope_base\"],\n", + " 
context_length=cfg[\"context_length\"],\n", + " partial_rotary_factor=cfg.get(\"partial_rotary_factor\", 1.0),\n", + " dtype=torch.float32,\n", + " )\n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + " self.cfg = cfg\n", + "\n", + " def forward(self, in_idx):\n", + " x = self.tok_emb(in_idx)\n", + "\n", + " num_tokens = x.shape[1]\n", + " mask = torch.triu(\n", + " torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool),\n", + " diagonal=1,\n", + " )\n", + "\n", + " for block in self.trf_blocks:\n", + " x = block(x, mask, self.cos, self.sin)\n", + "\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": { + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" + }, + "source": [ + " \n", + "# 2. Initialize model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": { + "id": "caa142fa-b375-4e78-b392-2072ced666f3" + }, + "outputs": [], + "source": [ + "# Qwen3.5-0.8B text configuration\n", + "QWEN3_5_CONFIG = {\n", + " \"vocab_size\": 248_320,\n", + " \"context_length\": 262_144,\n", + " \"emb_dim\": 1_024,\n", + " \"n_heads\": 8,\n", + " \"n_layers\": 24,\n", + " \"hidden_dim\": 3_584,\n", + " \"head_dim\": 256,\n", + " \"qk_norm\": True,\n", + " \"n_kv_groups\": 2,\n", + " \"rope_base\": 10_000_000.0,\n", + " \"partial_rotary_factor\": 0.25,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"linear_conv_kernel_dim\": 4,\n", + " \"linear_key_head_dim\": 128,\n", + " \"linear_value_head_dim\": 128,\n", + " \"linear_num_key_heads\": 16,\n", + " \"linear_num_value_heads\": 16,\n", + " \"dtype\": torch.bfloat16,\n", + " \"layer_types\": [\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", 
\"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n", + " ],\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": { + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d\n" + ] + } + ], + "source": [ + "torch.manual_seed(123)\n", + "model = Qwen3_5Model(QWEN3_5_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eaf86265-4e9d-4024-9ed0-99076944e304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Qwen3_5Model(\n", + " (tok_emb): Embedding(248320, 1024)\n", + " (trf_blocks): ModuleList(\n", + " (0-2): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " 
(fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (3): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (4-6): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (7): 
TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (8-10): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (11): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): 
Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (12-14): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (15): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): 
Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (16-18): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (19): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (20-22): 3 x TransformerBlock(\n", + " (token_mixer): Qwen3_5GatedDeltaNet(\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, 
bias=False)\n", + " (norm): Qwen3_5RMSNormGated()\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n", + " (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n", + " (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n", + " (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " (23): TransformerBlock(\n", + " (token_mixer): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=1024, out_features=512, bias=False)\n", + " (W_value): Linear(in_features=1024, out_features=512, bias=False)\n", + " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n", + " (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n", + " )\n", + " (norm1): RMSNorm()\n", + " (norm2): RMSNorm()\n", + " )\n", + " )\n", + " (final_norm): RMSNorm()\n", + " (out_head): Linear(in_features=1024, out_features=248320, bias=False)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b", + "metadata": {}, + "source": [ + "- A quick check that the forward pass works before continuing:" + ] + }, + { + "cell_type": "code", + 
"execution_count": 13, + "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.6719, -0.0347, -0.5938, ..., 0.5469, 0.1660, -0.8945],\n", + " [ 0.0391, -0.1226, -0.8789, ..., -0.6523, -0.8281, -0.0889],\n", + " [ 0.1992, -0.7930, -0.3359, ..., -0.6602, 0.0515, -0.1582]]],\n", + " dtype=torch.bfloat16, grad_fn=)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(torch.tensor([1, 2, 3]).unsqueeze(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", + "outputId": "00d7e983-262e-4c65-f322-f4d999311988" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 1,006,672,704\n", + "\n", + "Total number of unique parameters: 752,393,024\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")\n", + "\n", + "# Account for weight tying\n", + "total_params_normalized = total_params - model.tok_emb.weight.numel()\n", + "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", + "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float32 (PyTorch default): 7.63 GB\n", + "bfloat16: 3.81 GB\n" + ] + } + ], + "source": [ + "def calc_model_memory_size(model, input_dtype=torch.float32):\n", + " total_params = 0\n", + " total_grads = 0\n", + " for param in model.parameters():\n", 
+ " # Calculate total number of elements per parameter\n", + " param_size = param.numel()\n", + " total_params += param_size\n", + " # Check if gradients are stored for this parameter\n", + " if param.requires_grad:\n", + " total_grads += param_size\n", + "\n", + " # Calculate buffer size (non-parameters that require memory)\n", + " total_buffers = sum(buf.numel() for buf in model.buffers())\n", + "\n", + " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", + " # We assume parameters and gradients are stored in the same type as input dtype\n", + " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", + " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", + "\n", + " # Convert bytes to gigabytes\n", + " total_memory_gb = total_memory_bytes / (1024**3)\n", + "\n", + " return total_memory_gb\n", + "\n", + "print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", + "print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", + "metadata": { + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "c172f89f-d301-439f-b809-46169e5f5945", + "metadata": { + "id": "c172f89f-d301-439f-b809-46169e5f5945" + }, + "source": [ + " \n", + "# 3. 
Load pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "75166128-5899-4995-9b88-9672e135650e", + "metadata": { + "id": "75166128-5899-4995-9b88-9672e135650e" + }, + "outputs": [], + "source": [ + "def load_weights_into_qwen3_5(model, param_config, params):\n", + " def assign(left, right, tensor_name=\"unknown\"):\n", + " if left.shape != right.shape:\n", + " raise ValueError(\n", + " f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n", + " )\n", + "\n", + " with torch.no_grad():\n", + " if isinstance(right, torch.Tensor):\n", + " left.copy_(right)\n", + " else:\n", + " left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n", + "\n", + " return left\n", + "\n", + " if \"model.embed_tokens.weight\" in params:\n", + " model_prefix = \"model\"\n", + " elif \"model.language_model.embed_tokens.weight\" in params:\n", + " model_prefix = \"model.language_model\"\n", + " else:\n", + " raise KeyError(\"Could not find embed token weights in checkpoint.\")\n", + "\n", + " def pkey(suffix):\n", + " return f\"{model_prefix}.{suffix}\"\n", + "\n", + " model.tok_emb.weight = assign(\n", + " model.tok_emb.weight,\n", + " params[pkey(\"embed_tokens.weight\")],\n", + " pkey(\"embed_tokens.weight\"),\n", + " )\n", + "\n", + " n_layers = param_config[\"n_layers\"]\n", + " layer_types = param_config.get(\"layer_types\", [\"full_attention\"] * n_layers)\n", + "\n", + " for l in range(n_layers):\n", + " block = model.trf_blocks[l]\n", + " layer_type = layer_types[l]\n", + "\n", + " if layer_type == \"full_attention\":\n", + " att = block.token_mixer\n", + " att.W_query.weight = assign(\n", + " att.W_query.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.q_proj.weight\")],\n", + " pkey(f\"layers.{l}.self_attn.q_proj.weight\"),\n", + " )\n", + " att.W_key.weight = assign(\n", + " att.W_key.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.k_proj.weight\")],\n", + " 
pkey(f\"layers.{l}.self_attn.k_proj.weight\"),\n", + " )\n", + " att.W_value.weight = assign(\n", + " att.W_value.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.v_proj.weight\")],\n", + " pkey(f\"layers.{l}.self_attn.v_proj.weight\"),\n", + " )\n", + " att.out_proj.weight = assign(\n", + " att.out_proj.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.o_proj.weight\")],\n", + " pkey(f\"layers.{l}.self_attn.o_proj.weight\"),\n", + " )\n", + " if hasattr(att, \"q_norm\") and att.q_norm is not None:\n", + " att.q_norm.weight = assign(\n", + " att.q_norm.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.q_norm.weight\")],\n", + " pkey(f\"layers.{l}.self_attn.q_norm.weight\"),\n", + " )\n", + " if hasattr(att, \"k_norm\") and att.k_norm is not None:\n", + " att.k_norm.weight = assign(\n", + " att.k_norm.weight,\n", + " params[pkey(f\"layers.{l}.self_attn.k_norm.weight\")],\n", + " pkey(f\"layers.{l}.self_attn.k_norm.weight\"),\n", + " )\n", + "\n", + " elif layer_type == \"linear_attention\":\n", + " lat = block.token_mixer\n", + " lat.dt_bias = assign(\n", + " lat.dt_bias,\n", + " params[pkey(f\"layers.{l}.linear_attn.dt_bias\")],\n", + " pkey(f\"layers.{l}.linear_attn.dt_bias\"),\n", + " )\n", + " lat.A_log = assign(\n", + " lat.A_log,\n", + " params[pkey(f\"layers.{l}.linear_attn.A_log\")],\n", + " pkey(f\"layers.{l}.linear_attn.A_log\"),\n", + " )\n", + " lat.conv1d.weight = assign(\n", + " lat.conv1d.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.conv1d.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.conv1d.weight\"),\n", + " )\n", + " lat.norm.weight = assign(\n", + " lat.norm.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.norm.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.norm.weight\"),\n", + " )\n", + " lat.out_proj.weight = assign(\n", + " lat.out_proj.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.out_proj.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.out_proj.weight\"),\n", + " )\n", + " lat.in_proj_qkv.weight = 
assign(\n", + " lat.in_proj_qkv.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.in_proj_qkv.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.in_proj_qkv.weight\"),\n", + " )\n", + " lat.in_proj_z.weight = assign(\n", + " lat.in_proj_z.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.in_proj_z.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.in_proj_z.weight\"),\n", + " )\n", + " lat.in_proj_b.weight = assign(\n", + " lat.in_proj_b.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.in_proj_b.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.in_proj_b.weight\"),\n", + " )\n", + " lat.in_proj_a.weight = assign(\n", + " lat.in_proj_a.weight,\n", + " params[pkey(f\"layers.{l}.linear_attn.in_proj_a.weight\")],\n", + " pkey(f\"layers.{l}.linear_attn.in_proj_a.weight\"),\n", + " )\n", + "\n", + " else:\n", + " raise ValueError(f\"Unsupported layer type: {layer_type}\")\n", + "\n", + " block.norm1.weight = assign(\n", + " block.norm1.weight,\n", + " params[pkey(f\"layers.{l}.input_layernorm.weight\")],\n", + " pkey(f\"layers.{l}.input_layernorm.weight\"),\n", + " )\n", + "\n", + " block.ff.fc1.weight = assign(\n", + " block.ff.fc1.weight,\n", + " params[pkey(f\"layers.{l}.mlp.gate_proj.weight\")],\n", + " pkey(f\"layers.{l}.mlp.gate_proj.weight\"),\n", + " )\n", + " block.ff.fc2.weight = assign(\n", + " block.ff.fc2.weight,\n", + " params[pkey(f\"layers.{l}.mlp.up_proj.weight\")],\n", + " pkey(f\"layers.{l}.mlp.up_proj.weight\"),\n", + " )\n", + " block.ff.fc3.weight = assign(\n", + " block.ff.fc3.weight,\n", + " params[pkey(f\"layers.{l}.mlp.down_proj.weight\")],\n", + " pkey(f\"layers.{l}.mlp.down_proj.weight\"),\n", + " )\n", + " block.norm2.weight = assign(\n", + " block.norm2.weight,\n", + " params[pkey(f\"layers.{l}.post_attention_layernorm.weight\")],\n", + " pkey(f\"layers.{l}.post_attention_layernorm.weight\"),\n", + " )\n", + "\n", + " model.final_norm.weight = assign(\n", + " model.final_norm.weight,\n", + " params[pkey(\"norm.weight\")],\n", + 
" pkey(\"norm.weight\"),\n", + " )\n", + "\n", + " if \"lm_head.weight\" in params:\n", + " model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", + " elif pkey(\"lm_head.weight\") in params:\n", + " model.out_head.weight = assign(model.out_head.weight, params[pkey(\"lm_head.weight\")], pkey(\"lm_head.weight\"))\n", + " else:\n", + " model.out_head.weight = model.tok_emb.weight\n", + " print(\"Model uses weight tying.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "9881b6995c3f49dc89e6992fd9ab660b", + "17a3174e65c54476b2e0d1faf8f011ca", + "1bbf2e62c0754d1593beb4105a7f1ac1", + "b82112e1dec645d98aa1c1ba64abcb61", + "271e2bd6a35e4a8b92de8697f7c0be5f", + "90a79523187446dfa692723b2e5833a7", + "431ffb83b8c14bf182f0430e07ea6154", + "a8f1b72a33dd4b548de23fbd95e0da18", + "25cc36132d384189acfbecc59483134b", + "bfd06423ad544218968648016e731a46", + "d029630b63ff44cf807ade428d2eb421" + ] + }, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f6b15da62e96419890bc93ade1dbabe3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading (incomplete total...): 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34b073857bf447cfaa9aa81f141b9f59", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 13 files: 0%| | 0/13 [00:00\",\n", + " \"<|im_start|>\", \"<|im_end|>\",\n", + " \"<|object_ref_start|>\", \"<|object_ref_end|>\",\n", + " \"<|box_start|>\", \"<|box_end|>\",\n", + " \"<|quad_start|>\", \"<|quad_end|>\",\n", + " \"<|vision_start|>\", 
"<|vision_end|>\",\n", + " \"<|vision_pad|>\", \"<|image_pad|>\", \"<|video_pad|>\",\n", + " \"<think>\", \"</think>\",\n", + " ]\n", + " _SPLIT_RE = re.compile(r\"(<\\|[^>]+?\\|>|<think>|</think>)\")\n", + "\n", + " def __init__(\n", + " self,\n", + " tokenizer_file_path=\"tokenizer.json\",\n", + " repo_id=None,\n", + " apply_chat_template=True,\n", + " add_generation_prompt=False,\n", + " add_thinking=False,\n", + " ):\n", + " self.apply_chat_template = apply_chat_template\n", + " self.add_generation_prompt = add_generation_prompt\n", + " self.add_thinking = add_thinking\n", + "\n", + " tok_file = Path(tokenizer_file_path)\n", + " self._tok = Tokenizer.from_file(str(tok_file))\n", + " self._special_to_id = {}\n", + " for t in self._SPECIALS:\n", + " tid = self._tok.token_to_id(t)\n", + " if tid is not None:\n", + " self._special_to_id[t] = tid\n", + "\n", + " self.pad_token_id = self._special_to_id[\"<|endoftext|>\"]\n", + " self.eos_token_id = self.pad_token_id\n", + "\n", + " if repo_id and \"Base\" not in repo_id:\n", + " eos_token = \"<|im_end|>\"\n", + " else:\n", + " eos_token = \"<|endoftext|>\"\n", + " if eos_token in self._special_to_id:\n", + " self.eos_token_id = self._special_to_id[eos_token]\n", + "\n", + " def encode(self, text, chat_wrapped=None):\n", + " if chat_wrapped is None:\n", + " chat_wrapped = self.apply_chat_template\n", + "\n", + " stripped = text.strip()\n", + " if stripped in self._special_to_id and \"\\n\" not in stripped:\n", + " return [self._special_to_id[stripped]]\n", + "\n", + " if chat_wrapped:\n", + " text = self._wrap_chat(text)\n", + "\n", + " ids = []\n", + " for part in filter(None, self._SPLIT_RE.split(text)):\n", + " if part in self._special_to_id:\n", + " ids.append(self._special_to_id[part])\n", + " else:\n", + " ids.extend(self._tok.encode(part).ids)\n", + " return ids\n", + "\n", + " def decode(self, ids):\n", + " return self._tok.decode(ids, skip_special_tokens=False)\n", + "\n", + " def _wrap_chat(self, user_msg):\n", + " # Mirrors Qwen3.5
chat_template behavior:\n", + " # add_generation_prompt + thinking => \"<think>\\n\"\n", + " # add_generation_prompt + no thinking => empty think scaffold\n", + " s = f\"<|im_start|>user\\n{user_msg}<|im_end|>\\n\"\n", + " if self.add_generation_prompt:\n", + " s += \"<|im_start|>assistant\\n\"\n", + " if self.add_thinking:\n", + " s += \"<think>\\n\"\n", + " else:\n", + " s += \"<think>\\n\\n</think>\\n\\n\"\n", + " return s" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_file_path = \"Qwen3.5-0.8B/tokenizer.json\"\n", + "\n", + "hf_hub_download(\n", + " repo_id=repo_id,\n", + " filename=\"tokenizer.json\",\n", + " local_dir=local_dir,\n", + ")\n", + "\n", + "tokenizer = Qwen3_5Tokenizer(\n", + " tokenizer_file_path=tokenizer_file_path,\n", + " repo_id=repo_id,\n", + " apply_chat_template=True,\n", + " add_generation_prompt=True,\n", + " add_thinking=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1946b534-e3af-431a-a222-391a60bfa892", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n<think>\\n'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"Give me a short introduction to large language models.\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "text = tokenizer.decode(input_token_ids)\n", + "text" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": { + "id": "57d07df1-4401-4792-b549-7c4cc5632323" + }, + "source": [ + " \n", + "# 4.
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": { + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" + }, + "outputs": [], + "source": [ + "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None):\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for _ in range(max_new_tokens):\n", + " out = model(token_ids)[:, -1]\n", + " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + "\n", + " if (eos_token_id is not None\n", + " and torch.all(next_token == eos_token_id)):\n", + " break\n", + "\n", + " yield next_token\n", + " \n", + " token_ids = torch.cat([token_ids, next_token], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", + "metadata": { + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thinking Process:\n", + "\n", + "1. **Analyze the Request:**\n", + " * **Topic:** Large Language Models (LLMs).\n", + " * **Task:** Give a short introduction.\n", + " * **Constraint:** \"Short\" (implies concise, clear, and impactful).\n", + "\n", + "2. **Identify Key Concepts:**\n", + " * What are they? (AI models trained on massive datasets).\n", + " * What do they do? (Generate text, code, etc.).\n", + " * How do they work? (Neural networks, transformers, training).\n", + " * Why are they important? (Efficiency, context, creativity).\n", + " * *Self-Correction/Refinement:* Keep it simple but accurate. Avoid overly technical jargon unless necessary, but \"transformers\" is a key term.\n", + "\n", + "3. **Drafting - Attempt 1 (Mental Outline):**\n", + " LLMs are big AI models. They are trained on huge amounts of data. They can understand and generate text. They are like a supercomputer for language. They are used in chatbots and coding.\n", + "\n", + "4. 
**Drafting - Attempt 2 (Adding Detail & Flow):**\n", + " Large Language Models (LLMs) are a type of artificial intelligence. They are trained on massive datasets of text. They use neural networks to understand and generate human-like text. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n", + "\n", + "5. **Drafting - Attempt 3 (Polishing for \"Short Introduction\"):**\n", + " Large Language Models (LLMs) are a type of artificial intelligence that can understand and generate human-like text. They are trained on massive datasets of text. They use neural networks to process information and create content. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n", + "\n", + "6. **Refining for Clarity and Impact:**\n", + " * Make it punchy.\n", + " * Highlight the \"transformers\" or \"neural networks\" aspect if needed, but keep it simple.\n", + " * Mention the \"big data\" aspect.\n", + "\n", + "7. 
**Final Polish (incorporating into the final output):**\n", + " * Start with a definition.\n", + " * Mention the core technology (neural networks).\n", + " * Mention the output\n", + "\n", + "Generation speed: 8.28 tokens/sec\n", + "GPU memory used: 2.54 GB\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "prompt = \"Give me a short introduction to large language models.\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " generated_tokens += 1\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b0ef78d8-e512-47c2-aaab-d236a6e7cad3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Here's" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " a thinking process that leads to the solution:\n", + "\n", + "1. 
**Analyze the Request:**\n", + " * **Scenario:** A shop applies two discounts and a tax.\n", + " * **Discount:** 20% off the original price.\n", + " * **Tax:** 10% added on top of the discounted price.\n", + " * **Question:** Is the final price higher or lower than the original? By how much?\n", + "\n", + "2. **Define Variables:**\n", + " * Let $P$ be the original price.\n", + "\n", + "3. **Step-by-Step Calculation:**\n", + "\n", + " * *Step 1: Apply the 20% discount.*\n", + " * Discount amount = $0.20 \\times P$\n", + " * Final price after discount = $P - 0.20P$\n", + " * Final price after discount = $0.80P$\n", + "\n", + " * *Step 2: Apply the 10% tax.*\n", + " * Tax amount = $0.10 \\times (\\text{Final price after discount})$\n", + " * Tax amount = $0.10 \\times (0.80P)$\n", + " * Tax amount = $0.08P$\n", + " * Final price after tax = Final price after discount + Tax amount\n", + " * Final price after tax = $0.80P + 0.08P$\n", + " * Final price after tax = $0.88P$\n", + "\n", + " * *Step 3: Compare Final Price to Original Price.*\n", + " * Original Price = $P$\n", + " * Final Price = $0.88P$\n", + " * Since $0.88 < 1$, the final price is lower.\n", + "\n", + " * *Step 4: Calculate the difference.*\n", + " * Difference = Final Price - Original Price\n", + " * Difference = $0.88P - P$\n", + " * Difference = $-0.12P$\n", + " * The difference is $0.12P$ (or 12% of the original price).\n", + "\n", + "4. **Verification:**\n", + " * Let's pick a specific number to make sure.\n", + " * Let $P = 100$.\n", + " * \n", + "\n", + "Generation speed: 9.02 tokens/sec\n", + "GPU memory used: 2.56 GB\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "prompt = \"A shop gives a 20% discount, then adds 10% tax. Is the final price higher or lower than the original? 
By how much?\"\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "start_time = time.perf_counter()\n", + "generated_tokens = 0\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " generated_tokens += 1\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "elapsed = time.perf_counter() - start_time\n", + "tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n", + "print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " def calc_gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + "\n", + " print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": { + "id": "549324d6-5c71-4147-ae21-2e67675faa3d" + }, + "source": [ + " \n", + "# What's next?" 
+ ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": { + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" + }, + "source": [ + "- Check out the [README.md](../11_qwen3/README.md), to use this model via the `llms_from_scratch` package\n", + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/16_qwen3.5/qwen3_5_transformers.py b/ch05/16_qwen3.5/qwen3_5_transformers.py new file mode 100644 index 0000000..a961db6 --- /dev/null +++ b/ch05/16_qwen3.5/qwen3_5_transformers.py @@ -0,0 +1,425 @@ +"""Qwen3.5 helper blocks copied from Hugging Face Transformers + +Source file: +transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py + +License: Apache License Version 2.0 +License URL: https://github.com/huggingface/transformers/blob/main/LICENSE +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Notebook shims for optional fast kernels in transformers +causal_conv1d_fn = None +causal_conv1d_update = None +chunk_gated_delta_rule = None +fused_recurrent_gated_delta_rule = None +FusedRMSNormGated = None +ACT2FN = {"silu": F.silu} +is_fast_path_available = False + + +class _NotebookLogger: + def __init__(self): + self._seen = set() + + def warning_once(self, msg): + 
if msg in self._seen: + return + self._seen.add(msg) + print(msg) + + +logger = _NotebookLogger() + + +# Placeholder types for copied annotations +class Qwen3_5Config: + pass + + +class Qwen3_5DynamicCache: + pass + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +class Qwen3_5RMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = 
F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +def l2norm(x, dim=-1, eps=1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = 
((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +# Copied verbatim from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +def torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = 
query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +# Copied from: +# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +# Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch +# in a mixed notebook implementation +class Qwen3_5GatedDeltaNet(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = 
self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. 
To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + # Notebook adaptation for dtype consistency. + if config.dtype is not None: + self.to(dtype=config.dtype) + + def forward( + self, + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. 
Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + 
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output diff --git a/ch05/16_qwen3.5/tests/qwen3_5_layer_debugger.py b/ch05/16_qwen3.5/tests/qwen3_5_layer_debugger.py new file mode 100644 index 0000000..733a424 --- /dev/null +++ b/ch05/16_qwen3.5/tests/qwen3_5_layer_debugger.py @@ -0,0 +1,275 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import sys +from pathlib import Path + +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +def _import_qwen3_5_classes(): + try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + return Qwen3_5TextConfig, Qwen3_5ForCausalLM + except Exception: + repo_root = Path(__file__).resolve().parents[3] + local_src = repo_root / "transformers-main" / "src" + if not local_src.exists(): + raise + + for name in list(sys.modules): + if name == "transformers" or name.startswith("transformers."): + del sys.modules[name] + sys.path.insert(0, str(local_src)) + + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + return Qwen3_5TextConfig, Qwen3_5ForCausalLM + + +try: + Qwen3_5TextConfig, Qwen3_5ForCausalLM = _import_qwen3_5_classes() +except Exception: + Qwen3_5TextConfig = None + Qwen3_5ForCausalLM = None + + +def tiny_debug_config(): + return { + 
"vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_groups": 2, + "rope_base": 1_000_000.0, + "partial_rotary_factor": 1.0, + "rms_norm_eps": 1e-6, + "linear_conv_kernel_dim": 2, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 2, + "linear_num_value_heads": 2, + "layer_types": ["linear_attention", "full_attention"], + "dtype": torch.float32, + } + + +def _hf_config_from_dict(cfg): + if Qwen3_5TextConfig is None: + raise ImportError("Qwen3.5 classes are required for the layer debugger.") + + hf_cfg = Qwen3_5TextConfig( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_groups"], + layer_types=cfg["layer_types"], + linear_conv_kernel_dim=cfg["linear_conv_kernel_dim"], + linear_key_head_dim=cfg["linear_key_head_dim"], + linear_value_head_dim=cfg["linear_value_head_dim"], + linear_num_key_heads=cfg["linear_num_key_heads"], + linear_num_value_heads=cfg["linear_num_value_heads"], + tie_word_embeddings=False, + use_cache=False, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=cfg.get("rms_norm_eps", 1e-6), + rope_parameters={ + "rope_type": "default", + "rope_theta": cfg["rope_base"], + "partial_rotary_factor": cfg.get("partial_rotary_factor", 1.0), + "mrope_interleaved": True, + "mrope_section": [2, 1, 1], + }, + torch_dtype=cfg.get("dtype", torch.float32), + ) + hf_cfg._attn_implementation = "eager" + return hf_cfg + + +def load_notebook_defs(nb_name="qwen3.5.ipynb"): + nb_dir = Path(__file__).resolve().parents[1] + if str(nb_dir) not in sys.path: + sys.path.insert(0, str(nb_dir)) + return import_definitions_from_notebook(nb_dir, nb_name) + + +def build_qwen3_5_pair(import_notebook_defs, 
cfg, hf_checkpoint=None): + if Qwen3_5ForCausalLM is None: + raise ImportError("Qwen3.5 classes are required for the layer debugger.") + + ours = import_notebook_defs.Qwen3_5Model(cfg) + + if hf_checkpoint: + hf_model = Qwen3_5ForCausalLM.from_pretrained( + hf_checkpoint, + torch_dtype=cfg.get("dtype", torch.float32), + attn_implementation="eager", + ) + else: + hf_cfg = _hf_config_from_dict(cfg) + hf_model = Qwen3_5ForCausalLM(hf_cfg) + + import_notebook_defs.load_weights_into_qwen3_5( + ours, + {"n_layers": cfg["n_layers"], "layer_types": cfg["layer_types"]}, + hf_model.state_dict(), + ) + hf_model.config.use_cache = False + + ours.eval() + hf_model.eval() + return ours, hf_model + + +def _attach_debug_hooks(model, is_hf): + traces = {} + handles = [] + + def hook(name): + def _record(_, __, output): + if isinstance(output, tuple): + output = output[0] + traces[name] = output.detach().to(torch.float32).cpu() + + return _record + + if is_hf: + core = model.model + handles.append(core.embed_tokens.register_forward_hook(hook("embedding"))) + for idx, layer in enumerate(core.layers): + handles.append(layer.register_forward_hook(hook(f"block_{idx}"))) + handles.append(core.norm.register_forward_hook(hook("final_norm"))) + handles.append(model.lm_head.register_forward_hook(hook("logits"))) + else: + handles.append(model.tok_emb.register_forward_hook(hook("embedding"))) + blocks = getattr(model, "trf_blocks", None) + if blocks is None: + blocks = getattr(model, "blocks", None) + if blocks is None: + raise AttributeError("Could not locate Qwen3.5 blocks on the local model.") + for idx, block in enumerate(blocks): + handles.append(block.register_forward_hook(hook(f"block_{idx}"))) + handles.append(model.final_norm.register_forward_hook(hook("final_norm"))) + handles.append(model.out_head.register_forward_hook(hook("logits"))) + + return traces, handles + + +def _layer_sort_key(name): + if name == "embedding": + return (0, 0) + if name.startswith("block_"): + idx = 
int(name.split("_")[1]) + return (1, idx) + if name == "final_norm": + return (2, 0) + if name == "logits": + return (3, 0) + return (4, name) + + +def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5): + ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False) + hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True) + + try: + with torch.inference_mode(): + ours(input_ids) + hf_model(input_ids, use_cache=False) + finally: + for h in ours_handles + hf_handles: + h.remove() + + layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key) + results = [] + for name in layer_names: + ours_tensor = ours_traces.get(name) + hf_tensor = hf_traces.get(name) + + if ours_tensor is None or hf_tensor is None: + results.append( + { + "name": name, + "status": "missing", + "ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape), + "hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape), + "max_diff": None, + "mean_abs_diff": None, + } + ) + continue + + if ours_tensor.shape != hf_tensor.shape: + results.append( + { + "name": name, + "status": "shape_mismatch", + "ours_shape": tuple(ours_tensor.shape), + "hf_shape": tuple(hf_tensor.shape), + "max_diff": None, + "mean_abs_diff": None, + } + ) + continue + + diff = (ours_tensor - hf_tensor).abs() + max_diff = float(diff.max().item()) + mean_diff = float(diff.mean().item()) + allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol) + results.append( + { + "name": name, + "status": "ok" if allclose else "mismatch", + "ours_shape": tuple(ours_tensor.shape), + "hf_shape": tuple(hf_tensor.shape), + "max_diff": max_diff, + "mean_abs_diff": mean_diff, + } + ) + return results + + +def format_report(differences): + lines = [] + for diff in sorted(differences, key=lambda d: _layer_sort_key(d["name"])): + if diff["status"] == "ok": + lines.append(f"[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}") + elif 
diff["status"] == "mismatch": + lines.append(f"[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}") + elif diff["status"] == "shape_mismatch": + lines.append(f"[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}") + else: + lines.append(f"[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}") + return "\n".join(lines) + + +if __name__ == "__main__": + if Qwen3_5ForCausalLM is None: + raise SystemExit( + "Qwen3.5 classes are unavailable. Install a recent transformers version or use local transformers-main." + ) + + import_notebook_defs = load_notebook_defs() + cfg = tiny_debug_config() + + ours_model, hf_model = build_qwen3_5_pair(import_notebook_defs, cfg) + torch.manual_seed(0) + input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long) + diffs = layerwise_differences(ours_model, hf_model, input_ids) + print(format_report(diffs)) diff --git a/ch05/16_qwen3.5/tests/test_qwen3_5_nb.py b/ch05/16_qwen3.5/tests/test_qwen3_5_nb.py new file mode 100644 index 0000000..91a0892 --- /dev/null +++ b/ch05/16_qwen3.5/tests/test_qwen3_5_nb.py @@ -0,0 +1,166 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
+# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +import sys +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +def _import_qwen3_5_classes(): + try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + return Qwen3_5TextConfig, Qwen3_5ForCausalLM + except Exception: + repo_root = Path(__file__).resolve().parents[3] + local_src = repo_root / "transformers-main" / "src" + if not local_src.exists(): + raise + + for name in list(sys.modules): + if name == "transformers" or name.startswith("transformers."): + del sys.modules[name] + sys.path.insert(0, str(local_src)) + + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + return Qwen3_5TextConfig, Qwen3_5ForCausalLM + + +transformers_installed = importlib.util.find_spec("transformers") is not None +if transformers_installed: + try: + Qwen3_5TextConfig, Qwen3_5ForCausalLM = _import_qwen3_5_classes() + except Exception: + transformers_installed = False + Qwen3_5TextConfig, Qwen3_5ForCausalLM = None, None +else: + Qwen3_5TextConfig, Qwen3_5ForCausalLM = None, None + + +@pytest.fixture +def import_notebook_defs(): + nb_dir = Path(__file__).resolve().parents[1] + if str(nb_dir) not in sys.path: + sys.path.insert(0, str(nb_dir)) + + mod = import_definitions_from_notebook(nb_dir, "qwen3.5.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "emb_dim": 32, + "hidden_dim": 64, + "n_layers": 2, + "n_heads": 4, + "head_dim": 8, 
+ "n_kv_groups": 1, + "qk_norm": False, + "dtype": torch.float32, + "rope_base": 10_000.0, + "context_length": 64, + "partial_rotary_factor": 1.0, + "rms_norm_eps": 1e-6, + "linear_conv_kernel_dim": 2, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 2, + "linear_num_value_heads": 2, + "layer_types": ["linear_attention", "full_attention"], + } + + +@torch.inference_mode() +def test_dummy_qwen3_5_forward(dummy_cfg_base, dummy_input, import_notebook_defs): + torch.manual_seed(123) + model = import_notebook_defs.Qwen3_5Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), ( + f"Expected shape (1, seq_len, vocab_size), got {out.shape}" + ) + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_qwen3_5_base_equivalence_with_transformers(import_notebook_defs): + cfg = { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_groups": 2, + "rope_base": 1_000_000.0, + "partial_rotary_factor": 1.0, + "rms_norm_eps": 1e-6, + "linear_conv_kernel_dim": 2, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 2, + "linear_num_value_heads": 2, + "layer_types": ["linear_attention", "full_attention"], + "dtype": torch.float32, + } + model = import_notebook_defs.Qwen3_5Model(cfg) + + hf_cfg = Qwen3_5TextConfig( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_groups"], + layer_types=cfg["layer_types"], + linear_conv_kernel_dim=cfg["linear_conv_kernel_dim"], + linear_key_head_dim=cfg["linear_key_head_dim"], + 
linear_value_head_dim=cfg["linear_value_head_dim"], + linear_num_key_heads=cfg["linear_num_key_heads"], + linear_num_value_heads=cfg["linear_num_value_heads"], + tie_word_embeddings=False, + use_cache=False, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=cfg["rms_norm_eps"], + rope_parameters={ + "rope_type": "default", + "rope_theta": cfg["rope_base"], + "partial_rotary_factor": cfg["partial_rotary_factor"], + "mrope_interleaved": True, + "mrope_section": [2, 1, 1], + }, + torch_dtype=torch.float32, + ) + hf_cfg._attn_implementation = "eager" + hf_model = Qwen3_5ForCausalLM(hf_cfg) + + hf_state = hf_model.state_dict() + param_config = {"n_layers": cfg["n_layers"], "layer_types": cfg["layer_types"]} + import_notebook_defs.load_weights_into_qwen3_5(model, param_config, hf_state) + + x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long) + ours_logits = model(x) + theirs_logits = hf_model(x, use_cache=False).logits + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)