diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb index 0c91305..81f931d 100644 --- a/ch04/01_main-chapter-code/ch04.ipynb +++ b/ch04/01_main-chapter-code/ch04.ipynb @@ -974,7 +974,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 21, "id": "252b78c2-4404-483b-84fe-a412e55c16fc", "metadata": {}, "outputs": [ @@ -1195,7 +1195,7 @@ "metadata": {}, "source": [ "- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n", - "- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we don't have to compute the softmax function explicitly)\n", + "- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we technically wouldn't even have to compute the softmax function explicitly)\n", "- In the next chapter, we will implement a more advanced `generate_text` function\n", "- The figure below depicts how the GPT model, given an input context, generates the next word token" ] @@ -1210,7 +1210,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 41, "id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6", "metadata": {}, "outputs": [], @@ -1232,8 +1232,11 @@ " # (batch, n_token, vocab_size) becomes (batch, vocab_size)\n", " logits = logits[:, -1, :] \n", "\n", + " # Apply softmax to get probabilities\n", + " probas = torch.softmax(logits, dim=-1) # (batch, vocab_size)\n", + "\n", " # Get the idx of the vocab entry with the highest logits value\n", - " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)\n", + " idx_next = torch.argmax(probas, dim=-1, keepdim=True) # (batch, 1)\n", "\n", " # Append sampled index to the running sequence\n", " idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)\n", @@ -1261,7 +1264,42 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 54, + "id": "bb3ffc8e-f95f-4a24-a978-939b8953ea3e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-1.4929, 4.4812, -1.6093], grad_fn=)\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([ 0.0000, 0.0012, 0.0000, ..., 0.0001, 0.0000,\n", + " 0.0000], grad_fn=)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b = logits[0, -1, :]\n", + "b[0] = -1.4929\n", + "b[1] = 4.4812\n", + "b[2] = -1.6093\n", + "\n", + "print(b[:3])\n", + "torch.softmax(b, dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, "id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3", "metadata": {}, "outputs": [ @@ -1286,7 +1324,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 43, "id": "a72a9b60-de66-44cf-b2f9-1e638934ada4", "metadata": {}, "outputs": [ diff --git a/ch04/01_main-chapter-code/figures/generate-text.webp b/ch04/01_main-chapter-code/figures/generate-text.webp index 0a59895..d7f38e6 100644 Binary files a/ch04/01_main-chapter-code/figures/generate-text.webp and b/ch04/01_main-chapter-code/figures/generate-text.webp differ