mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
make softmax explicit
This commit is contained in:
@@ -974,7 +974,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 21,
|
||||
"id": "252b78c2-4404-483b-84fe-a412e55c16fc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1195,7 +1195,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
|
||||
"- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we don't have to compute the softmax function explicitly)\n",
|
||||
"- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we technically wouldn't even have to compute the softmax function explicitly)\n",
|
||||
"- In the next chapter, we will implement a more advanced `generate_text` function\n",
|
||||
"- The figure below depicts how the GPT model, given an input context, generates the next word token"
|
||||
]
|
||||
@@ -1210,7 +1210,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 41,
|
||||
"id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1232,8 +1232,11 @@
|
||||
" # (batch, n_token, vocab_size) becomes (batch, vocab_size)\n",
|
||||
" logits = logits[:, -1, :] \n",
|
||||
"\n",
|
||||
" # Apply softmax to get probabilities\n",
|
||||
" probas = torch.softmax(logits, dim=-1) # (batch, vocab_size)\n",
|
||||
"\n",
|
||||
" # Get the idx of the vocab entry with the highest logits value\n",
|
||||
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)\n",
|
||||
" idx_next = torch.argmax(probas, dim=-1, keepdim=True) # (batch, 1)\n",
|
||||
"\n",
|
||||
" # Append sampled index to the running sequence\n",
|
||||
" idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)\n",
|
||||
@@ -1261,7 +1264,42 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 54,
|
||||
"id": "bb3ffc8e-f95f-4a24-a978-939b8953ea3e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([-1.4929, 4.4812, -1.6093], grad_fn=<SliceBackward0>)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([ 0.0000, 0.0012, 0.0000, ..., 0.0001, 0.0000,\n",
|
||||
" 0.0000], grad_fn=<SoftmaxBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"b = logits[0, -1, :]\n",
|
||||
"b[0] = -1.4929\n",
|
||||
"b[1] = 4.4812\n",
|
||||
"b[2] = -1.6093\n",
|
||||
"\n",
|
||||
"print(b[:3])\n",
|
||||
"torch.softmax(b, dim=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1286,7 +1324,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 43,
|
||||
"id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 36 KiB |
Reference in New Issue
Block a user