{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
|
|
"metadata": {
|
|
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
|
|
},
|
|
"source": [
|
|
"<table style=\"width:100%\">\n",
|
|
"<tr>\n",
|
|
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
"<font size=\"2\">\n",
|
|
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
|
|
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
|
|
"</font>\n",
|
|
"</td>\n",
|
|
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
|
|
"</td>\n",
|
|
"</tr>\n",
|
|
"</table>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
|
|
"metadata": {
|
|
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
|
|
},
|
|
"source": [
|
|
"# Qwen3.5 From Scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
|
|
"metadata": {
|
|
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
|
|
},
|
|
"source": [
|
"- This notebook is purposefully minimal and focuses on a readable re-implementation of the Qwen3.5 text stack for the [Qwen/Qwen3.5-0.8B](https://huggingface.co/Qwen/Qwen3.5-0.8B) checkpoint on Hugging Face, mapped onto the scaffold I used for the other from-scratch implementations in this repo\n",
"- Qwen3.5 interleaves `linear_attention` and `full_attention` layers; in the 0.8B model, every three linear-attention layers are followed by one full-attention layer\n",
"- Note that this notebook is not 100% standalone and from-scratch, as it reuses some code from the Hugging Face transformers library (namely, `Qwen3_5GatedDeltaNet` for the linear-attention layers); the relevant parts are in the [qwen3_5_transformers.py](qwen3_5_transformers.py) file"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b304d453-f7da-4e17-8330-3a08a67ae3b1",
|
|
"metadata": {},
|
|
"source": [
|
|
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen3.5/01.webp\" width=\"500px\">"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1241a20b-d196-4521-9228-d46954d383e4",
|
|
"metadata": {},
|
|
"source": [
|
|
"- Qwen3.5 is based on the Qwen3-Next architecture, which I described in more detail in section [2. (Linear) Attention Hybrids](https://magazine.sebastianraschka.com/i/177848019/2-linear-attention-hybrids) of my [Beyond Standard LLMs](https://magazine.sebastianraschka.com/p/beyond-standard-llms) article"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "21d38944-0c98-40a6-a6f8-c745769b4618",
|
|
"metadata": {},
|
|
"source": [
|
|
"<a href=\"https://magazine.sebastianraschka.com/p/beyond-standard-llms\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen3.5/02.webp\" width=\"500px\"></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "7c201adb-747e-437b-9a62-442802941e01",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
|
|
"outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"huggingface_hub version: 1.5.0\n",
|
|
"tokenizers version: 0.22.2\n",
|
|
"torch version: 2.8.0+cu128\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from importlib.metadata import version\n",
|
|
"\n",
|
|
"pkgs = [\n",
|
|
" \"huggingface_hub\", # to download pretrained weights\n",
|
|
" \"tokenizers\", # to implement the tokenizer\n",
|
|
" \"torch\", # to implement the model\n",
|
|
"]\n",
|
|
"for p in pkgs:\n",
|
|
" print(f\"{p} version: {version(p)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "70a90338-624a-4706-aa55-6b4358070194",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"USE_MODEL = \"Qwen3.5-0.8B\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
|
|
"metadata": {
|
|
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
|
|
},
|
|
"source": [
|
|
" \n",
|
|
"# 1. Architecture code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
|
|
"metadata": {
|
|
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
"    def __init__(self, cfg):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"\n",
"    def forward(self, x):\n",
"        x_fc1 = self.fc1(x)\n",
"        x_fc2 = self.fc2(x)\n",
"        x = nn.functional.silu(x_fc1) * x_fc2\n",
"        return self.fc3(x)"
|
]
|
|
},
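{
"cell_type": "markdown",
"id": "swiglu-sketch-md",
"metadata": {},
"source": [
"- `FeedForward` is a SwiGLU-style gated MLP. `fc1` computes a gate that is passed through SiLU, and `fc2` computes the values that the gate multiplies elementwise. `fc3` then projects the result back to `emb_dim`\n",
"- The minimal sketch below repeats this computation with plain tensor operations. The toy sizes and random matrices are illustrative assumptions, not Qwen3.5 weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "swiglu-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"emb_dim, hidden_dim = 8, 32  # toy sizes for illustration only\n",
"x = torch.randn(2, 3, emb_dim)  # (batch, num_tokens, emb_dim)\n",
"\n",
"W1 = torch.randn(hidden_dim, emb_dim)  # gate projection (role of fc1)\n",
"W2 = torch.randn(hidden_dim, emb_dim)  # value projection (role of fc2)\n",
"W3 = torch.randn(emb_dim, hidden_dim)  # down projection (role of fc3)\n",
"\n",
"h1 = x @ W1.T\n",
"# SiLU is defined as x * sigmoid(x)\n",
"assert torch.allclose(F.silu(h1), h1 * torch.sigmoid(h1), atol=1e-5)\n",
"\n",
"out = (F.silu(h1) * (x @ W2.T)) @ W3.T\n",
"print(out.shape)  # torch.Size([2, 3, 8])"
]
},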
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "56715760-37e1-433e-89da-04864c139a9e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"class RMSNorm(nn.Module):\n",
"    def __init__(self, emb_dim, eps=1e-6):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        # Qwen3.5 uses (1 + weight) scaling with zero init\n",
"        self.weight = nn.Parameter(torch.zeros(emb_dim))\n",
"\n",
"    def _norm(self, x):\n",
"        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n",
"\n",
"    def forward(self, x):\n",
"        x_norm = self._norm(x.float())\n",
"        x_norm = x_norm * (1.0 + self.weight.float())\n",
"        return x_norm.to(dtype=x.dtype)"
|
]
|
|
},
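{
"cell_type": "markdown",
"id": "rmsnorm-sketch-md",
"metadata": {},
"source": [
"- `RMSNorm` applies the learned weight as `(1 + weight)`, so the zero-initialized weight corresponds to an identity scale. Every normalized vector therefore starts out with a root mean square of (approximately) 1\n",
"- For the same reason, an RMSNorm implementation that multiplies by `weight` directly would need the weight `1 + weight` to match\n",
"- A minimal standalone sketch of both points; `rms_norm_one_plus` is a hypothetical helper that mirrors `RMSNorm.forward`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rmsnorm-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def rms_norm_one_plus(x, weight, eps=1e-6):\n",
"    # Same math as RMSNorm.forward: normalize in float32, then scale by (1 + weight)\n",
"    x_f = x.float()\n",
"    x_norm = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)\n",
"    return (x_norm * (1.0 + weight.float())).to(x.dtype)\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"x = 3.0 * torch.randn(2, 5, 16)\n",
"\n",
"# Zero weight -> identity scale -> RMS of each output vector is ~1\n",
"y = rms_norm_one_plus(x, torch.zeros(16))\n",
"rms = y.pow(2).mean(dim=-1).sqrt()\n",
"print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))  # True\n",
"\n",
"# A conventional RMSNorm that multiplies by its weight directly needs weight_std = 1 + w\n",
"w = 0.1 * torch.randn(16)\n",
"x_std = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * (1.0 + w)\n",
"print(torch.allclose(rms_norm_one_plus(x, w), x_std, atol=1e-5))  # True"
]
},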
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
|
|
"metadata": {
|
|
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def compute_rope_params(\n",
"    head_dim,\n",
"    theta_base=10_000,\n",
"    context_length=4096,\n",
"    partial_rotary_factor=1.0,\n",
"    dtype=torch.float32,\n",
"):\n",
"    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
"    rotary_dim = int(head_dim * partial_rotary_factor)\n",
"    rotary_dim = max(2, rotary_dim - (rotary_dim % 2))\n",
"\n",
"    inv_freq = 1.0 / (\n",
"        theta_base ** (\n",
"            torch.arange(0, rotary_dim, 2, dtype=dtype)[: (rotary_dim // 2)].float() / rotary_dim\n",
"        )\n",
"    )\n",
"\n",
"    positions = torch.arange(context_length, dtype=dtype)\n",
"    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n",
"    angles = torch.cat([angles, angles], dim=1)\n",
"\n",
"    cos = torch.cos(angles)\n",
"    sin = torch.sin(angles)\n",
"\n",
"    return cos, sin\n",
"\n",
"\n",
"def apply_rope(x, cos, sin, offset=0):\n",
"    _, _, seq_len, head_dim = x.shape\n",
"    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
"    rot_dim = cos.shape[-1]\n",
"    if rot_dim > head_dim:\n",
"        raise ValueError(f\"RoPE dim {rot_dim} cannot exceed head_dim {head_dim}.\")\n",
"\n",
"    x_rot = x[..., :rot_dim]\n",
"    x_pass = x[..., rot_dim:]\n",
"\n",
"    x1 = x_rot[..., : rot_dim // 2]\n",
"    x2 = x_rot[..., rot_dim // 2 :]\n",
"\n",
"    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
"    rotated = torch.cat((-x2, x1), dim=-1)\n",
"    x_rotated = (x_rot * cos) + (rotated * sin)\n",
"\n",
"    x_out = torch.cat([x_rotated, x_pass], dim=-1)\n",
"    return x_out.to(dtype=x.dtype)"
|
]
|
|
},
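{
"cell_type": "markdown",
"id": "partial-rope-sketch-md",
"metadata": {},
"source": [
"- The Qwen3.5-0.8B config further below uses `partial_rotary_factor=0.25` and `head_dim=256`. With these values, `compute_rope_params` returns `cos`/`sin` tables of width `int(256 * 0.25) = 64`\n",
"- As a result, `apply_rope` rotates only the first 64 dimensions of each query/key head and passes the remaining 192 dimensions through unchanged\n",
"- The sketch below re-derives this on a random input with 8 positions, which is an illustrative assumption. It checks that the pass-through part is unchanged and that the rotation preserves vector norms"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "partial-rope-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"head_dim, partial_rotary_factor, theta_base = 256, 0.25, 10_000_000.0  # values from QWEN3_5_CONFIG\n",
"rotary_dim = int(head_dim * partial_rotary_factor)  # 64\n",
"\n",
"# Same frequency layout and rotate-half convention as compute_rope_params/apply_rope\n",
"inv_freq = 1.0 / (theta_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))\n",
"angles = torch.arange(8).float()[:, None] * inv_freq[None, :]  # 8 positions\n",
"angles = torch.cat([angles, angles], dim=-1)\n",
"cos, sin = torch.cos(angles), torch.sin(angles)\n",
"\n",
"x = torch.randn(1, 1, 8, head_dim)  # (batch, heads, num_tokens, head_dim)\n",
"x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]\n",
"x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]\n",
"x_rotated = x_rot * cos + torch.cat((-x2, x1), dim=-1) * sin\n",
"out = torch.cat([x_rotated, x_pass], dim=-1)\n",
"\n",
"print(rotary_dim, cos.shape)  # 64 torch.Size([8, 64])\n",
"print(torch.equal(out[..., rotary_dim:], x_pass))  # True: last 192 dims are untouched\n",
"print(torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4))  # True: rotation preserves the norm"
]
},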
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
|
"metadata": {
|
|
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class GroupedQueryAttention(nn.Module):\n",
"    def __init__(\n",
"        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n",
"    ):\n",
"        super().__init__()\n",
"        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
"\n",
"        self.num_heads = num_heads\n",
"        self.num_kv_groups = num_kv_groups\n",
"        self.group_size = num_heads // num_kv_groups\n",
"\n",
"        if head_dim is None:\n",
"            assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n",
"            head_dim = d_in // num_heads\n",
"\n",
"        self.head_dim = head_dim\n",
"        self.d_out = num_heads * head_dim\n",
"\n",
"        # Qwen3.5 full-attention uses a gated Q projection (2x output dim)\n",
"        self.W_query = nn.Linear(d_in, self.d_out * 2, bias=False, dtype=dtype)\n",
"        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
"        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
"\n",
"        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n",
"\n",
"        if qk_norm:\n",
"            self.q_norm = RMSNorm(head_dim, eps=1e-6)\n",
"            self.k_norm = RMSNorm(head_dim, eps=1e-6)\n",
"        else:\n",
"            self.q_norm = self.k_norm = None\n",
"\n",
"    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n",
"        b, num_tokens, _ = x.shape\n",
"\n",
"        q_and_gate = self.W_query(x)\n",
"        q_and_gate = q_and_gate.view(b, num_tokens, self.num_heads, self.head_dim * 2)\n",
"        queries, gate = torch.chunk(q_and_gate, 2, dim=-1)\n",
"        gate = gate.reshape(b, num_tokens, self.d_out)\n",
"\n",
"        keys = self.W_key(x)\n",
"        values = self.W_value(x)\n",
"\n",
"        queries = queries.transpose(1, 2)\n",
"        keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
"        values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
"\n",
"        if self.q_norm:\n",
"            queries = self.q_norm(queries)\n",
"        if self.k_norm:\n",
"            keys_new = self.k_norm(keys_new)\n",
"\n",
"        prev_len = 0\n",
"        if cache is not None:\n",
"            prev_k, prev_v = cache\n",
"            if prev_k is not None:\n",
"                prev_len = prev_k.size(2)\n",
"                keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)\n",
"                values_cat_raw = torch.cat([prev_v, values_new], dim=2)\n",
"            else:\n",
"                keys_cat_raw = keys_new\n",
"                values_cat_raw = values_new\n",
"        else:\n",
"            keys_cat_raw = keys_new\n",
"            values_cat_raw = values_new\n",
"\n",
"        queries = apply_rope(queries, cos, sin, offset=start_pos)\n",
"        keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)\n",
"\n",
"        keys = keys.repeat_interleave(self.group_size, dim=1)\n",
"        values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n",
"\n",
"        if cache is not None and cache[0] is not None:\n",
"            next_cache = (\n",
"                torch.cat([cache[0], keys_new], dim=2),\n",
"                torch.cat([cache[1], values_new], dim=2),\n",
"            )\n",
"        else:\n",
"            next_cache = (keys_new, values_new)\n",
"\n",
"        attn_scores = queries @ keys.transpose(2, 3)\n",
"        attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
"        attn_weights = torch.softmax(\n",
"            attn_scores * (self.head_dim ** -0.5),\n",
"            dim=-1,\n",
"            dtype=torch.float32,\n",
"        ).to(queries.dtype)\n",
"\n",
"        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
"\n",
"        # Apply the sigmoid output gate (the second half of the W_query output)\n",
"        context = context * torch.sigmoid(gate)\n",
"        out = self.out_proj(context)\n",
"        return out, next_cache"
|
]
|
|
},
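{
"cell_type": "markdown",
"id": "gqa-gate-sketch-md",
"metadata": {},
"source": [
"- In `GroupedQueryAttention`, each of the `num_kv_groups` key/value heads is shared by `group_size = num_heads // num_kv_groups` query heads via `repeat_interleave`. As a result, query head `h` uses KV head `h // group_size`\n",
"- The attention output is then multiplied elementwise by `sigmoid(gate)`, where `gate` is the second half of the `W_query` output. Because the sigmoid lies in (0, 1), the gate can only shrink the magnitude of each value\n",
"- A standalone sketch with the Qwen3.5-0.8B head counts (8 query heads, 2 KV groups) and a toy `head_dim`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "gqa-gate-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"num_heads, num_kv_groups, head_dim, num_tokens = 8, 2, 4, 5  # head_dim is a toy value\n",
"group_size = num_heads // num_kv_groups  # 4 query heads per KV head\n",
"\n",
"keys = torch.randn(1, num_kv_groups, num_tokens, head_dim)  # (b, kv_heads, seq, head_dim)\n",
"expanded = keys.repeat_interleave(group_size, dim=1)        # (b, num_heads, seq, head_dim)\n",
"for h in range(num_heads):\n",
"    assert torch.equal(expanded[:, h], keys[:, h // group_size])\n",
"print(expanded.shape)  # torch.Size([1, 8, 5, 4])\n",
"\n",
"# Output gate: context and gate both have shape (b, num_tokens, num_heads * head_dim)\n",
"context = torch.randn(1, num_tokens, num_heads * head_dim)\n",
"gate = torch.randn(1, num_tokens, num_heads * head_dim)\n",
"gated = context * torch.sigmoid(gate)\n",
"print((gated.abs() <= context.abs()).all().item())  # True"
]
},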
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
|
"metadata": {
|
|
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"from qwen3_5_transformers import (\n",
"    Qwen3_5GatedDeltaNet,\n",
")\n",
"\n",
"# Maps this notebook's config keys to the attribute names the Hugging Face transformers implementation expects\n",
"class _Qwen3_5ConfigAdapter:\n",
"    def __init__(self, cfg):\n",
"        self.hidden_size = cfg[\"emb_dim\"]\n",
"        self.linear_num_value_heads = cfg[\"linear_num_value_heads\"]\n",
"        self.linear_num_key_heads = cfg[\"linear_num_key_heads\"]\n",
"        self.linear_key_head_dim = cfg[\"linear_key_head_dim\"]\n",
"        self.linear_value_head_dim = cfg[\"linear_value_head_dim\"]\n",
"        self.linear_conv_kernel_dim = cfg[\"linear_conv_kernel_dim\"]\n",
"        self.hidden_act = \"silu\"\n",
"        self.rms_norm_eps = cfg.get(\"rms_norm_eps\", 1e-6)\n",
"        self.dtype = cfg.get(\"dtype\", None)\n",
"\n",
"\n",
"class TransformerBlock(nn.Module):\n",
"    def __init__(self, cfg, layer_type, layer_idx):\n",
"        super().__init__()\n",
"        self.layer_type = layer_type\n",
"\n",
"        if layer_type == \"full_attention\":\n",
"            self.token_mixer = GroupedQueryAttention(\n",
"                d_in=cfg[\"emb_dim\"],\n",
"                num_heads=cfg[\"n_heads\"],\n",
"                head_dim=cfg[\"head_dim\"],\n",
"                num_kv_groups=cfg[\"n_kv_groups\"],\n",
"                qk_norm=cfg[\"qk_norm\"],\n",
"                dtype=cfg[\"dtype\"],\n",
"            )\n",
"        elif layer_type == \"linear_attention\":\n",
"            self.token_mixer = Qwen3_5GatedDeltaNet(_Qwen3_5ConfigAdapter(cfg), layer_idx)\n",
"        else:\n",
"            raise ValueError(f\"Unsupported layer type: {layer_type}\")\n",
"\n",
"        self.ff = FeedForward(cfg)\n",
"        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n",
"        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n",
"\n",
"    def forward(self, x, mask, cos, sin, start_pos=0, cache=None, linear_cache=None, cache_position=None):\n",
"        shortcut = x\n",
"        x = self.norm1(x)\n",
"\n",
"        if self.layer_type == \"full_attention\":\n",
"            x, next_cache = self.token_mixer(\n",
"                x,\n",
"                mask,\n",
"                cos,\n",
"                sin,\n",
"                start_pos=start_pos,\n",
"                cache=cache,\n",
"            )\n",
"        else:\n",
"            x = self.token_mixer(\n",
"                x,\n",
"                cache_params=linear_cache,\n",
"                cache_position=cache_position,\n",
"            )\n",
"            next_cache = None\n",
"\n",
"        x = x + shortcut\n",
"\n",
"        shortcut = x\n",
"        x = self.norm2(x)\n",
"        x = self.ff(x)\n",
"        x = x + shortcut\n",
"\n",
"        return x, next_cache"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
|
"metadata": {
|
|
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class Qwen3_5Model(nn.Module):\n",
"    def __init__(self, cfg):\n",
"        super().__init__()\n",
"\n",
"        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
"\n",
"        layer_types = cfg.get(\"layer_types\", [\"full_attention\"] * cfg[\"n_layers\"])\n",
"        if len(layer_types) != cfg[\"n_layers\"]:\n",
"            raise ValueError(\"len(layer_types) must equal n_layers\")\n",
"\n",
"        self.trf_blocks = nn.ModuleList(\n",
"            [TransformerBlock(cfg, layer_type, idx) for idx, layer_type in enumerate(layer_types)]\n",
"        )\n",
"\n",
"        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg.get(\"rms_norm_eps\", 1e-6))\n",
"        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
"        head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"] if cfg[\"head_dim\"] is None else cfg[\"head_dim\"]\n",
"        cos, sin = compute_rope_params(\n",
"            head_dim=head_dim,\n",
"            theta_base=cfg[\"rope_base\"],\n",
"            context_length=cfg[\"context_length\"],\n",
"            partial_rotary_factor=cfg.get(\"partial_rotary_factor\", 1.0),\n",
"            dtype=torch.float32,\n",
"        )\n",
"        self.register_buffer(\"cos\", cos, persistent=False)\n",
"        self.register_buffer(\"sin\", sin, persistent=False)\n",
"        self.cfg = cfg\n",
"        self.current_pos = 0\n",
"\n",
"    def create_mask(self, cur_len, device, pos_start=0, pos_end=None):\n",
"        if pos_end is None:\n",
"            pos_end = cur_len\n",
"\n",
"        ones = torch.ones((pos_end, pos_end), device=device, dtype=torch.bool)\n",
"        mask_full = torch.triu(ones, diagonal=1)\n",
"        row_slice = slice(pos_start, pos_end)\n",
"        mask = mask_full[row_slice, :pos_end][None, None, :, :]\n",
"        return mask\n",
"\n",
"    def forward(self, in_idx, cache=None):\n",
"        x = self.tok_emb(in_idx)\n",
"\n",
"        num_tokens = x.shape[1]\n",
"        if cache is not None:\n",
"            pos_start = self.current_pos\n",
"            pos_end = pos_start + num_tokens\n",
"            self.current_pos = pos_end\n",
"            mask = self.create_mask(\n",
"                cur_len=num_tokens,\n",
"                device=x.device,\n",
"                pos_start=pos_start,\n",
"                pos_end=pos_end,\n",
"            )\n",
"            cache_position = torch.arange(pos_start, pos_end, device=x.device, dtype=torch.long)\n",
"        else:\n",
"            pos_start = 0\n",
"            mask = self.create_mask(\n",
"                cur_len=num_tokens,\n",
"                device=x.device,\n",
"                pos_start=0,\n",
"                pos_end=num_tokens,\n",
"            )\n",
"            cache_position = None\n",
"\n",
"        for i, block in enumerate(self.trf_blocks):\n",
"            blk_cache = cache.get(i) if cache is not None else None\n",
"            x, new_blk_cache = block(\n",
"                x,\n",
"                mask=mask,\n",
"                cos=self.cos,\n",
"                sin=self.sin,\n",
"                start_pos=pos_start,\n",
"                cache=blk_cache,\n",
"                linear_cache=cache.linear_cache if cache is not None else None,\n",
"                cache_position=cache_position,\n",
"            )\n",
"            if cache is not None and new_blk_cache is not None:\n",
"                cache.update(i, new_blk_cache)\n",
"\n",
"        if cache is not None:\n",
"            cache.linear_cache.has_previous_state = True\n",
"\n",
"        x = self.final_norm(x)\n",
"        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
"        return logits\n",
"\n",
"    def reset_kv_cache(self):\n",
"        self.current_pos = 0\n",
"\n",
"\n",
"class Qwen3_5LinearAttentionCache:\n",
"    def __init__(self, n_layers):\n",
"        self.conv_states = [None] * n_layers\n",
"        self.recurrent_states = [None] * n_layers\n",
"        self.has_previous_state = False\n",
"\n",
"    def reset(self):\n",
"        for i in range(len(self.conv_states)):\n",
"            self.conv_states[i] = None\n",
"            self.recurrent_states[i] = None\n",
"        self.has_previous_state = False\n",
"\n",
"\n",
"class KVCache:\n",
"    def __init__(self, n_layers):\n",
"        self.cache = [None] * n_layers\n",
"        self.linear_cache = Qwen3_5LinearAttentionCache(n_layers)\n",
"\n",
"    def get(self, layer_idx):\n",
"        return self.cache[layer_idx]\n",
"\n",
"    def update(self, layer_idx, value):\n",
"        self.cache[layer_idx] = value\n",
"\n",
"    def get_all(self):\n",
"        return self.cache\n",
"\n",
"    def reset(self):\n",
"        for i in range(len(self.cache)):\n",
"            self.cache[i] = None\n",
"        self.linear_cache.reset()"
|
]
|
|
},
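{
"cell_type": "markdown",
"id": "causal-mask-sketch-md",
"metadata": {},
"source": [
"- `create_mask` builds a boolean causal mask, where `True` marks future positions that are masked out\n",
"- During cached decoding, only the rows for the new positions `pos_start:pos_end` are kept. A single new token therefore gets an all-`False` row and can attend to every cached token\n",
"- A standalone copy of that logic as a plain function (the function form is just for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "causal-mask-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def create_mask(pos_start, pos_end, device=\"cpu\"):\n",
"    # Same logic as Qwen3_5Model.create_mask: True = masked (future) position\n",
"    ones = torch.ones((pos_end, pos_end), device=device, dtype=torch.bool)\n",
"    mask_full = torch.triu(ones, diagonal=1)\n",
"    return mask_full[pos_start:pos_end, :pos_end][None, None, :, :]\n",
"\n",
"\n",
"prefill = create_mask(0, 4)  # 4-token prompt, no cache\n",
"print(prefill[0, 0].int())   # strictly upper triangle is 1 (masked)\n",
"\n",
"decode = create_mask(4, 5)   # one new token at position 4\n",
"print(decode.shape, decode.any().item())  # torch.Size([1, 1, 1, 5]) False"
]
},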
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
|
|
"metadata": {
|
|
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
|
|
},
|
|
"source": [
|
|
" \n",
|
|
"# 2. Initialize model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
|
"metadata": {
|
|
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Qwen3.5-0.8B text configuration\n",
|
|
"QWEN3_5_CONFIG = {\n",
|
|
" \"vocab_size\": 248_320,\n",
|
|
" \"context_length\": 262_144,\n",
|
|
" \"emb_dim\": 1_024,\n",
|
|
" \"n_heads\": 8,\n",
|
|
" \"n_layers\": 24,\n",
|
|
" \"hidden_dim\": 3_584,\n",
|
|
" \"head_dim\": 256,\n",
|
|
" \"qk_norm\": True,\n",
|
|
" \"n_kv_groups\": 2,\n",
|
|
" \"rope_base\": 10_000_000.0,\n",
|
|
" \"partial_rotary_factor\": 0.25,\n",
|
|
" \"rms_norm_eps\": 1e-6,\n",
|
|
" \"linear_conv_kernel_dim\": 4,\n",
|
|
" \"linear_key_head_dim\": 128,\n",
|
|
" \"linear_value_head_dim\": 128,\n",
|
|
" \"linear_num_key_heads\": 16,\n",
|
|
" \"linear_num_value_heads\": 16,\n",
|
|
" \"dtype\": torch.bfloat16,\n",
|
|
" \"layer_types\": [\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" \"linear_attention\", \"linear_attention\", \"linear_attention\", \"full_attention\",\n",
|
|
" ],\n",
|
|
"}"
|
|
]
|
|
},
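{
"cell_type": "markdown",
"id": "layer-pattern-sketch-md",
"metadata": {},
"source": [
"- The `layer_types` list above follows a fixed 3:1 pattern: three `linear_attention` (Gated DeltaNet) layers followed by one `full_attention` layer, repeated 6 times\n",
"- A compact way to generate the same list and locate the full-attention layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "layer-pattern-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Three linear-attention layers, then one full-attention layer, repeated 6 times\n",
"layer_types = ([\"linear_attention\"] * 3 + [\"full_attention\"]) * 6\n",
"full_attn_idx = [i for i, t in enumerate(layer_types) if t == \"full_attention\"]\n",
"\n",
"print(len(layer_types))  # 24\n",
"print(full_attn_idx)     # [3, 7, 11, 15, 19, 23]"
]
},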
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
|
"metadata": {
|
|
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"torch.manual_seed(123)\n",
|
|
"model = Qwen3_5Model(QWEN3_5_CONFIG)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Qwen3_5Model(\n",
|
|
" (tok_emb): Embedding(248320, 1024)\n",
|
|
" (trf_blocks): ModuleList(\n",
|
|
" (0-2): 3 x TransformerBlock(\n",
|
|
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
|
|
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
|
|
" (norm): Qwen3_5RMSNormGated()\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
|
|
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
|
|
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (3): TransformerBlock(\n",
|
|
" (token_mixer): GroupedQueryAttention(\n",
|
|
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
|
|
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (q_norm): RMSNorm()\n",
|
|
" (k_norm): RMSNorm()\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (4-6): 3 x TransformerBlock(\n",
|
|
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
|
|
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
|
|
" (norm): Qwen3_5RMSNormGated()\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
|
|
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
|
|
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (7): TransformerBlock(\n",
|
|
" (token_mixer): GroupedQueryAttention(\n",
|
|
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
|
|
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (q_norm): RMSNorm()\n",
|
|
" (k_norm): RMSNorm()\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (8-10): 3 x TransformerBlock(\n",
|
|
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
|
|
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
|
|
" (norm): Qwen3_5RMSNormGated()\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
|
|
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
|
|
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (11): TransformerBlock(\n",
|
|
" (token_mixer): GroupedQueryAttention(\n",
|
|
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
|
|
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (q_norm): RMSNorm()\n",
|
|
" (k_norm): RMSNorm()\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (12-14): 3 x TransformerBlock(\n",
|
|
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
|
|
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
|
|
" (norm): Qwen3_5RMSNormGated()\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
|
|
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
|
|
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (15): TransformerBlock(\n",
|
|
" (token_mixer): GroupedQueryAttention(\n",
|
|
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
|
|
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
|
|
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
|
|
" (q_norm): RMSNorm()\n",
|
|
" (k_norm): RMSNorm()\n",
|
|
" )\n",
|
|
" (ff): FeedForward(\n",
|
|
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
|
|
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
|
|
" )\n",
|
|
" (norm1): RMSNorm()\n",
|
|
" (norm2): RMSNorm()\n",
|
|
" )\n",
|
|
" (16-18): 3 x TransformerBlock(\n",
|
|
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
|
|
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
|
|
" (norm): Qwen3_5RMSNormGated()\n",
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" (19): TransformerBlock(\n",
" (token_mixer): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" (20-22): 3 x TransformerBlock(\n",
" (token_mixer): Qwen3_5GatedDeltaNet(\n",
" (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)\n",
" (norm): Qwen3_5RMSNormGated()\n",
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
" (in_proj_qkv): Linear(in_features=1024, out_features=6144, bias=False)\n",
" (in_proj_z): Linear(in_features=1024, out_features=2048, bias=False)\n",
" (in_proj_b): Linear(in_features=1024, out_features=16, bias=False)\n",
" (in_proj_a): Linear(in_features=1024, out_features=16, bias=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" (23): TransformerBlock(\n",
" (token_mixer): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=1024, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=1024, out_features=512, bias=False)\n",
" (W_value): Linear(in_features=1024, out_features=512, bias=False)\n",
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc2): Linear(in_features=1024, out_features=3584, bias=False)\n",
" (fc3): Linear(in_features=3584, out_features=1024, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=1024, out_features=248320, bias=False)\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "markdown",
"id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
"metadata": {},
"source": [
"- A quick check that the forward pass works before continuing:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.6719, -0.0347, -0.5938, ..., 0.5469, 0.1660, -0.8945],\n",
" [ 0.0391, -0.1226, -0.8789, ..., -0.6523, -0.8281, -0.0889],\n",
" [ 0.1992, -0.7930, -0.3359, ..., -0.6602, 0.0515, -0.1582]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(torch.tensor([1, 2, 3]).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 1,006,672,704\n",
"\n",
"Total number of unique parameters: 752,393,024\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
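The two printed counts differ by exactly the size of the token-embedding matrix, which the output head reuses once the weights are tied. A quick arithmetic check, assuming a vocabulary of 248,320 tokens and an embedding dimension of 1,024 as read off the `out_head` layer in the model printout above:

```python
# Shapes read off the printed model: out_head is Linear(1024 -> 248320)
vocab_size = 248_320
emb_dim = 1_024

total_params = 1_006_672_704  # printed "Total number of parameters"
embedding_params = vocab_size * emb_dim

# With weight tying, the embedding matrix is counted once instead of twice
unique_params = total_params - embedding_params
print(f"Embedding parameters: {embedding_params:,}")  # 254,279,680
print(f"Unique parameters:    {unique_params:,}")     # 752,393,024
```

The result matches the printed "unique parameters" value, which confirms that the only duplicated tensor is the embedding/output-head matrix.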
{
"cell_type": "code",
"execution_count": 15,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 7.63 GB\n",
"bfloat16: 3.81 GB\n"
]
}
],
"source": [
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
"    total_params = 0\n",
"    total_grads = 0\n",
"    for param in model.parameters():\n",
"        # Calculate total number of elements per parameter\n",
"        param_size = param.numel()\n",
"        total_params += param_size\n",
"        # Check if gradients are stored for this parameter\n",
"        if param.requires_grad:\n",
"            total_grads += param_size\n",
"\n",
"    # Calculate buffer size (non-parameters that require memory)\n",
"    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
"\n",
"    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
"    # We assume parameters and gradients are stored in the same type as input dtype\n",
"    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
"    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
"\n",
"    # Convert bytes to gigabytes\n",
"    total_memory_gb = total_memory_bytes / (1024**3)\n",
"\n",
"    return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
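`calc_model_memory_size` multiplies (parameters + gradients + buffers) by the byte width of the chosen dtype and divides by 1024³. The sketch below reproduces that formula on plain integer counts. It assumes every parameter requires a gradient and sets the buffer count to zero (a simplification: the real model's buffer count is not printed), so its results come out slightly below the printed 7.63 GB and 3.81 GB; the gap is presumably the model's buffers.

```python
# Byte widths of the two dtypes compared above
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}

def estimate_memory_gb(n_params, n_grads, n_buffers, dtype):
    """Same formula as calc_model_memory_size, applied to plain counts."""
    total_bytes = (n_params + n_grads + n_buffers) * BYTES_PER_ELEMENT[dtype]
    return total_bytes / 1024**3

n_params = 1_006_672_704  # printed total, counted before weight tying
for dtype in ("float32", "bfloat16"):
    # Assumption: all parameters require grad; buffers ignored (n_buffers=0)
    gb = estimate_memory_gb(n_params, n_params, 0, dtype)
    print(f"{dtype}: {gb:.2f} GB")  # float32: 7.50 GB, bfloat16: 3.75 GB
```

Halving the element size halves the estimate, which is why the bfloat16 figure is exactly half of the float32 one in both this sketch and the printed output.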
{
"cell_type": "code",
"execution_count": 16,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
"    device = torch.device(\"mps\")\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "c172f89f-d301-439f-b809-46169e5f5945",
"metadata": {
"id": "c172f89f-d301-439f-b809-46169e5f5945"
},
"source": [
" \n",
"# 3. Load pretrained weights"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
},
"outputs": [],
"source": [
"def load_weights_into_qwen3_5(model, param_config, params):\n",
"    def assign(left, right, tensor_name=\"unknown\"):\n",
"        if left.shape != right.shape:\n",
"            raise ValueError(\n",
"                f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\"\n",
"            )\n",
"\n",
"        with torch.no_grad():\n",
"            if isinstance(right, torch.Tensor):\n",
"                left.copy_(right)\n",
"            else:\n",
"                left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
"\n",
"        return left\n",
"\n",
"    if \"model.embed_tokens.weight\" in params:\n",
"        model_prefix = \"model\"\n",
"    elif \"model.language_model.embed_tokens.weight\" in params:\n",
"        model_prefix = \"model.language_model\"\n",
"    else:\n",
"        raise KeyError(\"Could not find embed token weights in checkpoint.\")\n",
"\n",
"    def pkey(suffix):\n",
"        return f\"{model_prefix}.{suffix}\"\n",
"\n",
"    model.tok_emb.weight = assign(\n",
"        model.tok_emb.weight,\n",
"        params[pkey(\"embed_tokens.weight\")],\n",
"        pkey(\"embed_tokens.weight\"),\n",
"    )\n",
"\n",
"    n_layers = param_config[\"n_layers\"]\n",
"    layer_types = param_config.get(\"layer_types\", [\"full_attention\"] * n_layers)\n",
"\n",
"    for l in range(n_layers):\n",
"        block = model.trf_blocks[l]\n",
"        layer_type = layer_types[l]\n",
"\n",
"        if layer_type == \"full_attention\":\n",
"            att = block.token_mixer\n",
"            att.W_query.weight = assign(\n",
"                att.W_query.weight,\n",
"                params[pkey(f\"layers.{l}.self_attn.q_proj.weight\")],\n",
"                pkey(f\"layers.{l}.self_attn.q_proj.weight\"),\n",
"            )\n",
"            att.W_key.weight = assign(\n",
"                att.W_key.weight,\n",
"                params[pkey(f\"layers.{l}.self_attn.k_proj.weight\")],\n",
"                pkey(f\"layers.{l}.self_attn.k_proj.weight\"),\n",
"            )\n",
"            att.W_value.weight = assign(\n",
"                att.W_value.weight,\n",
"                params[pkey(f\"layers.{l}.self_attn.v_proj.weight\")],\n",
"                pkey(f\"layers.{l}.self_attn.v_proj.weight\"),\n",
"            )\n",
"            att.out_proj.weight = assign(\n",
"                att.out_proj.weight,\n",
"                params[pkey(f\"layers.{l}.self_attn.o_proj.weight\")],\n",
"                pkey(f\"layers.{l}.self_attn.o_proj.weight\"),\n",
"            )\n",
"            if hasattr(att, \"q_norm\") and att.q_norm is not None:\n",
"                att.q_norm.weight = assign(\n",
"                    att.q_norm.weight,\n",
"                    params[pkey(f\"layers.{l}.self_attn.q_norm.weight\")],\n",
"                    pkey(f\"layers.{l}.self_attn.q_norm.weight\"),\n",
"                )\n",
"            if hasattr(att, \"k_norm\") and att.k_norm is not None:\n",
"                att.k_norm.weight = assign(\n",
"                    att.k_norm.weight,\n",
"                    params[pkey(f\"layers.{l}.self_attn.k_norm.weight\")],\n",
"                    pkey(f\"layers.{l}.self_attn.k_norm.weight\"),\n",
"                )\n",
"\n",
"        elif layer_type == \"linear_attention\":\n",
"            lat = block.token_mixer\n",
"            lat.dt_bias = assign(\n",
"                lat.dt_bias,\n",
"                params[pkey(f\"layers.{l}.linear_attn.dt_bias\")],\n",
"                pkey(f\"layers.{l}.linear_attn.dt_bias\"),\n",
"            )\n",
"            lat.A_log = assign(\n",
"                lat.A_log,\n",
"                params[pkey(f\"layers.{l}.linear_attn.A_log\")],\n",
"                pkey(f\"layers.{l}.linear_attn.A_log\"),\n",
"            )\n",
"            lat.conv1d.weight = assign(\n",
"                lat.conv1d.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.conv1d.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.conv1d.weight\"),\n",
"            )\n",
"            lat.norm.weight = assign(\n",
"                lat.norm.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.norm.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.norm.weight\"),\n",
"            )\n",
"            lat.out_proj.weight = assign(\n",
"                lat.out_proj.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.out_proj.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.out_proj.weight\"),\n",
"            )\n",
"            lat.in_proj_qkv.weight = assign(\n",
"                lat.in_proj_qkv.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.in_proj_qkv.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.in_proj_qkv.weight\"),\n",
"            )\n",
"            lat.in_proj_z.weight = assign(\n",
"                lat.in_proj_z.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.in_proj_z.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.in_proj_z.weight\"),\n",
"            )\n",
"            lat.in_proj_b.weight = assign(\n",
"                lat.in_proj_b.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.in_proj_b.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.in_proj_b.weight\"),\n",
"            )\n",
"            lat.in_proj_a.weight = assign(\n",
"                lat.in_proj_a.weight,\n",
"                params[pkey(f\"layers.{l}.linear_attn.in_proj_a.weight\")],\n",
"                pkey(f\"layers.{l}.linear_attn.in_proj_a.weight\"),\n",
"            )\n",
"\n",
"        else:\n",
"            raise ValueError(f\"Unsupported layer type: {layer_type}\")\n",
"\n",
"        block.norm1.weight = assign(\n",
"            block.norm1.weight,\n",
"            params[pkey(f\"layers.{l}.input_layernorm.weight\")],\n",
"            pkey(f\"layers.{l}.input_layernorm.weight\"),\n",
"        )\n",
"\n",
"        block.ff.fc1.weight = assign(\n",
"            block.ff.fc1.weight,\n",
"            params[pkey(f\"layers.{l}.mlp.gate_proj.weight\")],\n",
"            pkey(f\"layers.{l}.mlp.gate_proj.weight\"),\n",
"        )\n",
"        block.ff.fc2.weight = assign(\n",
"            block.ff.fc2.weight,\n",
"            params[pkey(f\"layers.{l}.mlp.up_proj.weight\")],\n",
"            pkey(f\"layers.{l}.mlp.up_proj.weight\"),\n",
"        )\n",
"        block.ff.fc3.weight = assign(\n",
"            block.ff.fc3.weight,\n",
"            params[pkey(f\"layers.{l}.mlp.down_proj.weight\")],\n",
"            pkey(f\"layers.{l}.mlp.down_proj.weight\"),\n",
"        )\n",
"        block.norm2.weight = assign(\n",
"            block.norm2.weight,\n",
"            params[pkey(f\"layers.{l}.post_attention_layernorm.weight\")],\n",
"            pkey(f\"layers.{l}.post_attention_layernorm.weight\"),\n",
"        )\n",
"\n",
"    model.final_norm.weight = assign(\n",
"        model.final_norm.weight,\n",
"        params[pkey(\"norm.weight\")],\n",
"        pkey(\"norm.weight\"),\n",
"    )\n",
"\n",
"    if \"lm_head.weight\" in params:\n",
"        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
"    elif pkey(\"lm_head.weight\") in params:\n",
"        model.out_head.weight = assign(model.out_head.weight, params[pkey(\"lm_head.weight\")], pkey(\"lm_head.weight\"))\n",
"    else:\n",
"        model.out_head.weight = model.tok_emb.weight\n",
"        print(\"Model uses weight tying.\")"
]
},
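`load_weights_into_qwen3_5` is essentially a name-mapping exercise: it detects whether the checkpoint keys live under `model.` or `model.language_model.`, then copies each Hugging Face tensor into the matching from-scratch module. A table-driven sketch of that mapping for a full-attention block plus the per-block norm and feed-forward weights, in plain Python without tensors. The helper names `detect_prefix` and `block_key_map` are hypothetical and not part of the notebook; the key strings are copied from the loader above.

```python
# Hypothetical table-driven version of the key mapping used by load_weights_into_qwen3_5

# From-scratch module path -> checkpoint sub-key (full-attention layers)
FULL_ATTN = {
    "token_mixer.W_query": "self_attn.q_proj",
    "token_mixer.W_key": "self_attn.k_proj",
    "token_mixer.W_value": "self_attn.v_proj",
    "token_mixer.out_proj": "self_attn.o_proj",
    "token_mixer.q_norm": "self_attn.q_norm",
    "token_mixer.k_norm": "self_attn.k_norm",
}
# Weights every block has, regardless of its token mixer
SHARED = {
    "norm1": "input_layernorm",
    "norm2": "post_attention_layernorm",
    "ff.fc1": "mlp.gate_proj",
    "ff.fc2": "mlp.up_proj",
    "ff.fc3": "mlp.down_proj",
}

def detect_prefix(checkpoint_keys):
    # Same two layouts the loader checks for
    if "model.embed_tokens.weight" in checkpoint_keys:
        return "model"
    if "model.language_model.embed_tokens.weight" in checkpoint_keys:
        return "model.language_model"
    raise KeyError("Could not find embed token weights in checkpoint.")

def block_key_map(layer_idx, prefix):
    """Map from-scratch parameter names to checkpoint keys for one full-attention block."""
    return {
        f"trf_blocks.{layer_idx}.{ours}.weight": f"{prefix}.layers.{layer_idx}.{theirs}.weight"
        for ours, theirs in {**FULL_ATTN, **SHARED}.items()
    }

prefix = detect_prefix({"model.language_model.embed_tokens.weight"})
m = block_key_map(23, prefix)  # block 23 is a full-attention block in the printout
print(m["trf_blocks.23.ff.fc1.weight"])
# model.language_model.layers.23.mlp.gate_proj.weight
```

Keeping the mapping in dictionaries makes it easy to print or diff against `params.keys()` when a checkpoint layout changes; the explicit `assign` calls in the notebook trade that compactness for readability.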
{
"cell_type": "code",
"execution_count": 18,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17,
"referenced_widgets": [
"9881b6995c3f49dc89e6992fd9ab660b",
"17a3174e65c54476b2e0d1faf8f011ca",
"1bbf2e62c0754d1593beb4105a7f1ac1",
"b82112e1dec645d98aa1c1ba64abcb61",
"271e2bd6a35e4a8b92de8697f7c0be5f",
"90a79523187446dfa692723b2e5833a7",
"431ffb83b8c14bf182f0430e07ea6154",
"a8f1b72a33dd4b548de23fbd95e0da18",
"25cc36132d384189acfbecc59483134b",
"bfd06423ad544218968648016e731a46",
"d029630b63ff44cf807ade428d2eb421"
]
},
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ca01175c472450786e4ae0201a39beb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (incomplete total...): 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "348a2193fba34101a79816dc808e8533",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uses weight tying.\n"
]
}
],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from safetensors.torch import load_file\n",
"from huggingface_hub import hf_hub_download, snapshot_download\n",
"\n",
"repo_id = \"Qwen/Qwen3.5-0.8B\"\n",
"local_dir = Path(repo_id).parts[-1]\n",
"\n",
"repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
"index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
"with open(index_path, \"r\") as f:\n",
"    index = json.load(f)\n",
"\n",
"weights_dict = {}\n",
"for filename in sorted(set(index[\"weight_map\"].values())):\n",
"    shard_path = os.path.join(repo_dir, filename)\n",
"    shard = load_file(shard_path)\n",
"    weights_dict.update(shard)\n",
"\n",
"load_weights_into_qwen3_5(model, QWEN3_5_CONFIG, weights_dict)\n",
"model.to(device)\n",
"del weights_dict"
]
},
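The loading cell above reads `model.safetensors.index.json`, whose `weight_map` entry maps each tensor name to the shard file that stores it, and then opens every distinct shard exactly once via `sorted(set(...))`. The same grouping idea in isolation, using a small made-up index; the tensor and shard file names below are illustrative only and do not describe the actual Qwen3.5-0.8B shard layout.

```python
from collections import defaultdict

# Made-up index in the same shape as model.safetensors.index.json
index = {
    "weight_map": {
        "model.language_model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
        "model.language_model.norm.weight": "model-00002-of-00002.safetensors",
    }
}

# Group tensor names by shard so each file only needs to be opened once
tensors_per_shard = defaultdict(list)
for tensor_name, shard_file in index["weight_map"].items():
    tensors_per_shard[shard_file].append(tensor_name)

for shard_file in sorted(tensors_per_shard):
    print(shard_file, len(tensors_per_shard[shard_file]))
```

Deduplicating the shard names matters because many tensors share one file; iterating over `weight_map` directly would load the same shard repeatedly.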
{
"cell_type": "markdown",
"id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
"metadata": {},
"source": [
" \n",
"# 4. Load tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from tokenizers import Tokenizer\n",
"\n",
"\n",
"class Qwen3_5Tokenizer:\n",
"    _SPECIALS = [\n",
"        \"<|endoftext|>\",\n",
"        \"<|im_start|>\", \"<|im_end|>\",\n",
"        \"<|object_ref_start|>\", \"<|object_ref_end|>\",\n",
"        \"<|box_start|>\", \"<|box_end|>\",\n",
"        \"<|quad_start|>\", \"<|quad_end|>\",\n",
"        \"<|vision_start|>\", \"<|vision_end|>\",\n",
"        \"<|vision_pad|>\", \"<|image_pad|>\", \"<|video_pad|>\",\n",
"        \"<think>\", \"</think>\",\n",
"    ]\n",
"    _SPLIT_RE = re.compile(r\"(<\\|[^>]+?\\|>|<think>|</think>)\")\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        tokenizer_file_path=\"tokenizer.json\",\n",
"        repo_id=None,\n",
"        apply_chat_template=True,\n",
"        add_generation_prompt=False,\n",
"        add_thinking=False,\n",
"    ):\n",
"        self.apply_chat_template = apply_chat_template\n",
"        self.add_generation_prompt = add_generation_prompt\n",
"        self.add_thinking = add_thinking\n",
"\n",
"        tok_file = Path(tokenizer_file_path)\n",
"        self._tok = Tokenizer.from_file(str(tok_file))\n",
"        self._special_to_id = {}\n",
"        for t in self._SPECIALS:\n",
"            tid = self._tok.token_to_id(t)\n",
"            if tid is not None:\n",
"                self._special_to_id[t] = tid\n",
"\n",
"        self.pad_token_id = self._special_to_id[\"<|endoftext|>\"]\n",
"        self.eos_token_id = self.pad_token_id\n",
"\n",
"        if repo_id and \"Base\" not in repo_id:\n",
"            eos_token = \"<|im_end|>\"\n",
"        else:\n",
"            eos_token = \"<|endoftext|>\"\n",
"        if eos_token in self._special_to_id:\n",
"            self.eos_token_id = self._special_to_id[eos_token]\n",
"\n",
"    def encode(self, text, chat_wrapped=None):\n",
"        if chat_wrapped is None:\n",
"            chat_wrapped = self.apply_chat_template\n",
"\n",
"        stripped = text.strip()\n",
"        if stripped in self._special_to_id and \"\\n\" not in stripped:\n",
"            return [self._special_to_id[stripped]]\n",
"\n",
"        if chat_wrapped:\n",
"            text = self._wrap_chat(text)\n",
"\n",
"        ids = []\n",
"        for part in filter(None, self._SPLIT_RE.split(text)):\n",
"            if part in self._special_to_id:\n",
"                ids.append(self._special_to_id[part])\n",
"            else:\n",
"                ids.extend(self._tok.encode(part).ids)\n",
"        return ids\n",
"\n",
"    def decode(self, ids):\n",
"        return self._tok.decode(ids, skip_special_tokens=False)\n",
"\n",
"    def _wrap_chat(self, user_msg):\n",
"        # Mirrors Qwen3.5 chat_template behavior:\n",
"        # add_generation_prompt + thinking => \"<think>\\n\"\n",
"        # add_generation_prompt + no thinking => empty think scaffold\n",
"        s = f\"<|im_start|>user\\n{user_msg}<|im_end|>\\n\"\n",
"        if self.add_generation_prompt:\n",
"            s += \"<|im_start|>assistant\\n\"\n",
"            if self.add_thinking:\n",
"                s += \"<think>\\n\"\n",
"            else:\n",
"                s += \"<think>\\n\\n</think>\\n\\n\"\n",
"        return s\n"
]
},
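The `encode` path has two string-level steps: `_wrap_chat` builds the chat-formatted prompt, and `_SPLIT_RE` cuts it into special-token pieces and plain-text pieces so that only the plain text goes through the BPE tokenizer. A dependency-free sketch of those two steps, reproducing the method bodies as standalone functions (no `tokenizers` library, so no token IDs are produced):

```python
import re

# Same pattern as Qwen3_5Tokenizer._SPLIT_RE
SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

def wrap_chat(user_msg, add_generation_prompt=True, add_thinking=True):
    # Standalone copy of Qwen3_5Tokenizer._wrap_chat
    s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
    if add_generation_prompt:
        s += "<|im_start|>assistant\n"
        if add_thinking:
            s += "<think>\n"
        else:
            s += "<think>\n\n</think>\n\n"  # empty think scaffold
    return s

prompt = wrap_chat("Hi")
# The capturing group keeps the special tokens in the split result;
# dropping empty strings mirrors filter(None, ...) in encode()
parts = [p for p in SPLIT_RE.split(prompt) if p]
print(parts)
# ['<|im_start|>', 'user\nHi', '<|im_end|>', '\n', '<|im_start|>', 'assistant\n', '<think>', '\n']
```

In the real class, pieces such as `'<|im_start|>'` map directly to a single ID through `_special_to_id`, while pieces such as `'user\nHi'` are passed to `self._tok.encode`.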
{
"cell_type": "code",
"execution_count": 20,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
"source": [
"tokenizer_file_path = \"Qwen3.5-0.8B/tokenizer.json\"\n",
"\n",
"hf_hub_download(\n",
"    repo_id=repo_id,\n",
"    filename=\"tokenizer.json\",\n",
"    local_dir=local_dir,\n",
")\n",
"\n",
"tokenizer = Qwen3_5Tokenizer(\n",
"    tokenizer_file_path=tokenizer_file_path,\n",
"    repo_id=repo_id,\n",
"    apply_chat_template=True,\n",
"    add_generation_prompt=True,\n",
"    add_thinking=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n<think>\\n'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = \"Give me a short introduction to large language models.\"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"text = tokenizer.decode(input_token_ids)\n",
"text"
]
},
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
" \n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None):\n",
"\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
"        model.reset_kv_cache()\n",
"\n",
"        # Prime the cache with the initial context\n",
"        logits = model(token_ids, cache=cache)\n",
"\n",
"        for _ in range(max_new_tokens):\n",
"            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
"\n",
"            if eos_token_id is not None and torch.all(next_token == eos_token_id):\n",
"                break\n",
"\n",
"            yield next_token\n",
"\n",
"            token_ids = torch.cat([token_ids, next_token], dim=1)\n",
"            # Feed only the new token to the model; cache handles history\n",
"            logits = model(next_token, cache=cache)"
]
},
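`generate_text_basic_stream` performs greedy decoding: at each step it takes the argmax over the last position's logits, stops before yielding the EOS token, and otherwise yields the token and feeds only that one token back into the model, relying on the cache for the earlier context. The same control flow with a toy stand-in for the model; `toy_logits` and `greedy_stream` are hypothetical names, and the toy "model" only looks at the most recent token, so there is no PyTorch and no real cache involved.

```python
def toy_logits(last_token, vocab_size=5):
    # Hypothetical stand-in for the model: always favors last_token + 1
    scores = [0.0] * vocab_size
    scores[(last_token + 1) % vocab_size] = 1.0
    return scores

def greedy_stream(prompt_ids, max_new_tokens, eos_token_id=None):
    logits = toy_logits(prompt_ids[-1])  # "prime" on the prompt
    for _ in range(max_new_tokens):
        # Greedy choice: index of the highest score (argmax)
        next_token = max(range(len(logits)), key=logits.__getitem__)
        if eos_token_id is not None and next_token == eos_token_id:
            break  # EOS ends generation and is not yielded
        yield next_token
        logits = toy_logits(next_token)  # feed only the new token

print(list(greedy_stream([0], max_new_tokens=10, eos_token_id=4)))  # [1, 2, 3]
```

Because the function is a generator, the caller can print each token as soon as it is produced, which is what the streaming cells below do with `tokenizer.decode`.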
{
"cell_type": "code",
"execution_count": 23,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thinking Process:\n",
"\n",
"1. **Analyze the Request:**\n",
" * **Topic:** Large Language Models (LLMs).\n",
" * **Task:** Give a short introduction.\n",
" * **Constraint:** \"Short\" (implies concise, clear, and impactful).\n",
"\n",
"2. **Identify Key Concepts:**\n",
" * What are they? (AI models trained on massive datasets).\n",
" * What do they do? (Generate text, code, etc.).\n",
" * How do they work? (Neural networks, transformers, training).\n",
" * Why are they important? (Efficiency, context, creativity).\n",
" * *Self-Correction/Refinement:* Keep it simple but accurate. Avoid overly technical jargon unless necessary, but \"transformers\" is a key term.\n",
"\n",
"3. **Drafting - Attempt 1 (Mental Outline):**\n",
" LLMs are big AI models. They are trained on huge amounts of data. They can understand and generate text. They are like a supercomputer for language. They are used in chatbots and coding.\n",
"\n",
"4. **Drafting - Attempt 2 (Adding Detail & Flow):**\n",
" Large Language Models (LLMs) are a type of artificial intelligence. They are trained on massive datasets of text. They use neural networks to understand and generate human-like text. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n",
"\n",
"5. **Drafting - Attempt 3 (Polishing for \"Short Introduction\"):**\n",
" Large Language Models (LLMs) are a type of artificial intelligence that can understand and generate human-like text. They are trained on massive datasets of text. They use neural networks to process information and create content. They are used in chatbots, coding assistants, and creative writing. They are becoming more powerful and efficient.\n",
"\n",
"6. **Refining for Clarity and Impact:**\n",
" * Make it punchy.\n",
" * Highlight the \"transformers\" or \"neural networks\" aspect if needed, but keep it simple.\n",
" * Mention the \"big data\" aspect.\n",
"\n",
"7. **Final Polish (incorporating into the final output):**\n",
" * Start with a definition.\n",
" * Mention the core technology (neural networks).\n",
" * Mention the output\n",
"\n",
"Generation speed: 8.25 tokens/sec\n",
"GPU memory used: 2.54 GB\n"
]
}
],
"source": [
"import time\n",
"\n",
"prompt = \"Give me a short introduction to large language models.\"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.reset_peak_memory_stats()\n",
"\n",
"start_time = time.perf_counter()\n",
"generated_tokens = 0\n",
"\n",
"for token in generate_text_basic_stream(\n",
"    model=model,\n",
"    token_ids=input_token_ids_tensor,\n",
"    max_new_tokens=500,\n",
"    eos_token_id=tokenizer.eos_token_id\n",
"):\n",
"    generated_tokens += 1\n",
"    token_id = token.squeeze(0).tolist()\n",
"    print(\n",
"        tokenizer.decode(token_id),\n",
"        end=\"\",\n",
"        flush=True\n",
"    )\n",
"\n",
"elapsed = time.perf_counter() - start_time\n",
"tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n",
"print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
"    def calc_gpu_gb(x):\n",
"        return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
"\n",
"    print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b0ef78d8-e512-47c2-aaab-d236a6e7cad3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Here's"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" a thinking process that leads to the solution:\n",
"\n",
"1. **Analyze the Request:**\n",
" * **Scenario:** A shop applies two discounts and a tax.\n",
" * **Discount:** 20% off the original price.\n",
" * **Tax:** 10% added on top of the discounted price.\n",
" * **Question:** Is the final price higher or lower than the original? By how much?\n",
"\n",
"2. **Define Variables:**\n",
" * Let $P$ be the original price.\n",
"\n",
"3. **Step-by-Step Calculation:**\n",
"\n",
" * *Step 1: Apply the 20% discount.*\n",
" * Discount amount = $0.20 \\times P$\n",
" * Final price after discount = $P - 0.20P$\n",
" * Final price after discount = $0.80P$\n",
"\n",
" * *Step 2: Apply the 10% tax.*\n",
" * Tax amount = $0.10 \\times (\\text{Final price after discount})$\n",
" * Tax amount = $0.10 \\times (0.80P)$\n",
" * Tax amount = $0.08P$\n",
" * Final price after tax = Final price after discount + Tax amount\n",
" * Final price after tax = $0.80P + 0.08P$\n",
" * Final price after tax = $0.88P$\n",
"\n",
" * *Step 3: Compare Final Price to Original Price.*\n",
" * Original Price = $P$\n",
" * Final Price = $0.88P$\n",
" * Since $0.88 < 1$, the final price is lower.\n",
"\n",
" * *Step 4: Calculate the difference.*\n",
" * Difference = Final Price - Original Price\n",
" * Difference = $0.88P - P$\n",
" * Difference = $-0.12P$\n",
" * The difference is $0.12P$ (or 12% of the original price).\n",
"\n",
"4. **Verification:**\n",
" * Let's pick a specific number to make sure.\n",
" * Let $P = 100$.\n",
" * \n",
"\n",
"Generation speed: 9.00 tokens/sec\n",
"GPU memory used: 2.56 GB\n"
]
}
],
"source": [
"import time\n",
"\n",
"prompt = \"A shop gives a 20% discount, then adds 10% tax. Is the final price higher or lower than the original? By how much?\"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.reset_peak_memory_stats()\n",
"\n",
"start_time = time.perf_counter()\n",
"generated_tokens = 0\n",
"\n",
"for token in generate_text_basic_stream(\n",
"    model=model,\n",
"    token_ids=input_token_ids_tensor,\n",
"    max_new_tokens=500,\n",
"    eos_token_id=tokenizer.eos_token_id\n",
"):\n",
"    generated_tokens += 1\n",
"    token_id = token.squeeze(0).tolist()\n",
"    print(\n",
"        tokenizer.decode(token_id),\n",
"        end=\"\",\n",
"        flush=True\n",
"    )\n",
"\n",
"elapsed = time.perf_counter() - start_time\n",
"tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n",
"print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
"    def calc_gpu_gb(x):\n",
"        return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
"\n",
"    print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n"
]
},
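In this second run, the reasoning trace stops (presumably at the `max_new_tokens=500` limit) right as the model begins its check with $P = 100$. The arithmetic it sets up, a 20% discount followed by 10% tax on the discounted price, can be verified exactly with Python's standard `fractions` module, which avoids floating-point rounding:

```python
from fractions import Fraction

original = Fraction(100)                          # P = 100
discounted = original * (1 - Fraction(20, 100))   # 20% off -> 80
final = discounted * (1 + Fraction(10, 100))      # +10% tax on the discounted price -> 88

# Relative change versus the original price
change = (final - original) / original
print(final, change)  # 88 -3/25
```

A change of -3/25 is -12%, so the final price is 12% lower than the original, matching the model's symbolic result of $0.88P$.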
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
" \n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- Check out the [README.md](../11_qwen3/README.md) to use this model via the `llms_from_scratch` package\n",
"- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my book [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}