From baa86179210ce3780ecab092b8b9092a9e49d064 Mon Sep 17 00:00:00 2001 From: rasbt Date: Sat, 10 Feb 2024 17:53:54 -0600 Subject: [PATCH] variable name fix --- ch04/01_main-chapter-code/ch04.ipynb | 30 +++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb index a560e56..e3f62e7 100644 --- a/ch04/01_main-chapter-code/ch04.ipynb +++ b/ch04/01_main-chapter-code/ch04.ipynb @@ -185,15 +185,12 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n", - " [ 6109, 1110, 6622, 257, 11483, 13]])" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n", + " [ 6109, 1110, 6622, 257, 11483, 13]])\n" + ] } ], "source": [ @@ -680,7 +677,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 16, "id": "928e7f7c-d0b1-499f-8d07-4cadb428a6f9", "metadata": {}, "outputs": [ @@ -741,7 +738,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 17, "id": "05473938-799c-49fd-86d4-8ed65f94fee6", "metadata": {}, "outputs": [ @@ -977,7 +974,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 31, "id": "252b78c2-4404-483b-84fe-a412e55c16fc", "metadata": {}, "outputs": [ @@ -985,6 +982,10 @@ "name": "stdout", "output_type": "stream", "text": [ + "Input batch:\n", + " tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n", + " [ 6109, 1110, 6622, 257, 11483, 13]])\n", + "\n", "Output shape: torch.Size([2, 6, 50257])\n", "tensor([[[ 0.2237, 0.1153, 0.1121, ..., 0.1412, -0.0542, -0.3782],\n", " [ 0.5285, -0.0155, -0.5074, ..., -0.3225, 0.4875, -0.0612],\n", @@ -1005,10 +1006,11 @@ ], "source": [ "torch.manual_seed(123)\n", - "model = GPT(GPT_CONFIG_124M)\n", + "model = GPTModel(GPT_CONFIG_124M)\n", "\n", "out = model(batch)\n", - "print(\"Output shape:\", out.shape)\n", + "print(\"Input batch:\\n\", batch)\n", + "print(\"\\nOutput shape:\", out.shape)\n", "print(out)" ] },