mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Readability and code quality improvements (#959)
* Consistent dataset naming * consistent section headers
This commit is contained in:
committed by
GitHub
parent
7b1f740f74
commit
be5e2a3331
@@ -85,6 +85,7 @@
|
||||
"id": "ecc4dcee-34ea-4c05-9085-2f8887f70363",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.1 The problem with modeling long sequences"
|
||||
]
|
||||
},
|
||||
@@ -127,6 +128,7 @@
|
||||
"id": "3602c585-b87a-41c7-a324-c5e8298849df",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.2 Capturing data dependencies with attention mechanisms"
|
||||
]
|
||||
},
|
||||
@@ -168,6 +170,7 @@
|
||||
"id": "5efe05ff-b441-408e-8d66-cde4eb3397e3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.3 Attending to different parts of the input with self-attention"
|
||||
]
|
||||
},
|
||||
@@ -176,6 +179,7 @@
|
||||
"id": "6d9af516-7c37-4400-ab53-34936d5495a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.3.1 A simple self-attention mechanism without trainable weights"
|
||||
]
|
||||
},
|
||||
@@ -216,7 +220,7 @@
|
||||
"id": "ff856c58-8382-44c7-827f-798040e6e697",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**\n"
|
||||
"- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -503,6 +507,7 @@
|
||||
"id": "5a454262-40eb-430e-9ca4-e43fb8d6cd89",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.3.2 Computing attention weights for all input tokens"
|
||||
]
|
||||
},
|
||||
@@ -739,6 +744,7 @@
|
||||
"id": "a303b6fb-9f7e-42bb-9fdb-2adabf0a6525",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.4 Implementing self-attention with trainable weights"
|
||||
]
|
||||
},
|
||||
@@ -763,6 +769,7 @@
|
||||
"id": "2b90a77e-d746-4704-9354-1ddad86e6298",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.4.1 Computing the attention weights step by step"
|
||||
]
|
||||
},
|
||||
@@ -1046,6 +1053,7 @@
|
||||
"id": "9d7b2907-e448-473e-b46c-77735a7281d8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.4.2 Implementing a compact SelfAttention class"
|
||||
]
|
||||
},
|
||||
@@ -1179,6 +1187,7 @@
|
||||
"id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.5 Hiding future words with causal attention"
|
||||
]
|
||||
},
|
||||
@@ -1203,6 +1212,7 @@
|
||||
"id": "82f405de-cd86-4e72-8f3c-9ea0354946ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.5.1 Applying a causal attention mask"
|
||||
]
|
||||
},
|
||||
@@ -1455,6 +1465,7 @@
|
||||
"id": "7636fc5f-6bc6-461e-ac6a-99ec8e3c0912",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.5.2 Masking additional attention weights with dropout"
|
||||
]
|
||||
},
|
||||
@@ -1554,6 +1565,7 @@
|
||||
"id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.5.3 Implementing a compact causal self-attention class"
|
||||
]
|
||||
},
|
||||
@@ -1679,6 +1691,7 @@
|
||||
"id": "c8bef90f-cfd4-4289-b0e8-6a00dc9be44c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## 3.6 Extending single-head attention to multi-head attention"
|
||||
]
|
||||
},
|
||||
@@ -1687,6 +1700,7 @@
|
||||
"id": "11697757-9198-4a1c-9cee-f450d8bbd3b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.6.1 Stacking multiple single-head attention layers"
|
||||
]
|
||||
},
|
||||
@@ -1776,6 +1790,7 @@
|
||||
"id": "6836b5da-ef82-4b4c-bda1-72a462e48d4e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"### 3.6.2 Implementing multi-head attention with weight splits"
|
||||
]
|
||||
},
|
||||
@@ -2032,7 +2047,8 @@
|
||||
"id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Summary and takeaways"
|
||||
" \n",
|
||||
"## Summary and takeaways"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -2061,7 +2077,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -54,7 +54,8 @@
|
||||
"id": "33dfa199-9aee-41d4-a64b-7e3811b9a616",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.1"
|
||||
" \n",
|
||||
"## Exercise 3.1"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -209,7 +210,8 @@
|
||||
"id": "33543edb-46b5-4b01-8704-f7f101230544",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.2"
|
||||
" \n",
|
||||
"## Exercise 3.2"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -266,7 +268,8 @@
|
||||
"id": "92bdabcb-06cf-4576-b810-d883bbd313ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.3"
|
||||
" \n",
|
||||
"## Exercise 3.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -339,7 +342,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -117,7 +117,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 1) CausalAttention MHA wrapper class from chapter 3"
|
||||
"## 1. CausalAttention MHA wrapper class from chapter 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -208,7 +208,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 2) The multi-head attention class from chapter 3"
|
||||
"## 2. The multi-head attention class from chapter 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -311,7 +311,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 3) An alternative multi-head attention with combined weights"
|
||||
"## 3. An alternative multi-head attention with combined weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -435,7 +435,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 4) Multi-head attention with Einsum\n",
|
||||
"## 4. Multi-head attention with Einsum\n",
|
||||
"\n",
|
||||
"- Implementing multi-head attention using Einstein summation via [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html)"
|
||||
]
|
||||
@@ -567,7 +567,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
|
||||
"## 5. Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -676,7 +676,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
|
||||
"## 6. PyTorch's scaled dot product attention without FlashAttention\n",
|
||||
"\n",
|
||||
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
|
||||
]
|
||||
@@ -785,7 +785,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 7) Using PyTorch's torch.nn.MultiheadAttention"
|
||||
"## 7. Using PyTorch's torch.nn.MultiheadAttention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -883,7 +883,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 8) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
|
||||
"## 8. Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -948,7 +948,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## 9) Using PyTorch's FlexAttention\n",
|
||||
"## 9. Using PyTorch's FlexAttention\n",
|
||||
"\n",
|
||||
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
|
||||
"- FlexAttention caveat: It currently doesn't support dropout\n",
|
||||
@@ -1108,7 +1108,18 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## Quick speed comparison (M3 Macbook Air CPU)"
|
||||
"## 10. Quick speed comparisons"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "992e28f4-a6b9-4dd3-9705-30d0b9f4b5f0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"### 10.1 Speed comparisons on M3 Macbook Air CPU"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1361,7 +1372,7 @@
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## Quick speed comparison (Nvidia A100 GPU)"
|
||||
"### 10.2 Quick speed comparison on Nvidia A100 GPU"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1643,7 +1654,18 @@
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Visualizations"
|
||||
"## 11. Visualizations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6baf5ce-45ac-4e26-9523-5c32b82dc784",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<br>\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"### 11.1 Visualization utility functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1752,7 +1774,8 @@
|
||||
"id": "4df834dc"
|
||||
},
|
||||
"source": [
|
||||
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
|
||||
" \n",
|
||||
"### 11.2 Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1834,7 +1857,7 @@
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
|
||||
"### 11.3 Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1920,7 +1943,7 @@
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
|
||||
"### 11.4 Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@ from llms_from_scratch.utils import import_definitions_from_notebook
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nb_imports():
|
||||
def import_notebook_defs():
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
mod = import_definitions_from_notebook(nb_dir, "mha-implementations.ipynb")
|
||||
return mod
|
||||
@@ -31,12 +31,12 @@ def copy_weights(from_mha, to_mha):
|
||||
(1024, 512, 2, 4, 8, 789), # d_in > d_out
|
||||
],
|
||||
)
|
||||
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, nb_imports):
|
||||
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, import_notebook_defs):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
x = torch.randn(batch, seq_len, d_in)
|
||||
|
||||
mha_linear = nb_imports.Ch03_MHA(
|
||||
mha_linear = import_notebook_defs.Ch03_MHA(
|
||||
d_in=d_in,
|
||||
d_out=d_out,
|
||||
context_length=seq_len,
|
||||
@@ -45,7 +45,7 @@ def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, n
|
||||
qkv_bias=False,
|
||||
).eval()
|
||||
|
||||
mha_einsum = nb_imports.MHAEinsum(
|
||||
mha_einsum = import_notebook_defs.MHAEinsum(
|
||||
d_in=d_in,
|
||||
d_out=d_out,
|
||||
context_length=seq_len,
|
||||
|
||||
Reference in New Issue
Block a user