mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
make a new example for shortcut connections
This commit is contained in:
@@ -738,43 +738,124 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 19,
|
||||
"id": "05473938-799c-49fd-86d4-8ed65f94fee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ExampleDeepNeuralNetwork(nn.Module):\n",
|
||||
" def __init__(self, layer_sizes, use_shortcut):\n",
|
||||
" super().__init__()\n",
|
||||
" self.use_shortcut = use_shortcut\n",
|
||||
" self.layers = nn.ModuleList([\n",
|
||||
" nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),\n",
|
||||
" nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),\n",
|
||||
" nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),\n",
|
||||
" nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),\n",
|
||||
" nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" for layer in self.layers:\n",
|
||||
" # Compute the output of the current layer\n",
|
||||
" layer_output = layer(x)\n",
|
||||
" # Check if shortcut can be applied\n",
|
||||
" if self.use_shortcut and x.size() == layer_output.size():\n",
|
||||
" x = x + layer_output\n",
|
||||
" else:\n",
|
||||
" x = layer_output\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def print_gradients(model, x):\n",
|
||||
" # Forward pass\n",
|
||||
" output = model(x)\n",
|
||||
" target = torch.tensor([[0.]])\n",
|
||||
"\n",
|
||||
" # Calculate loss based on how close the target\n",
|
||||
" # and output are\n",
|
||||
" loss = nn.MSELoss()\n",
|
||||
" loss = loss(output, target)\n",
|
||||
" \n",
|
||||
" # Backward pass to calculate the gradients\n",
|
||||
" loss.backward()\n",
|
||||
"\n",
|
||||
" for name, param in model.named_parameters():\n",
|
||||
" if 'weight' in name:\n",
|
||||
" # Print the mean absolute gradient of the weights\n",
|
||||
" print(f\"{name} has gradient mean of {param.grad.abs().mean().item()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b39bf277-b3db-4bb1-84ce-7a20caff1011",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Let's print the gradient values first **without** shortcut connections:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "c75f43cc-6923-4018-b980-26023086572c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Before shortcut: tensor([[0.0950, 0.0634, 0.3361]], grad_fn=<MulBackward0>)\n",
|
||||
"After shortcut: tensor([[-0.9050, 1.0634, 2.3361]], grad_fn=<AddBackward0>)\n",
|
||||
"Final network output: tensor([[0.2427]], grad_fn=<AddmmBackward0>)\n"
|
||||
"layers.0.0.weight has gradient mean of 0.00020173587836325169\n",
|
||||
"layers.1.0.weight has gradient mean of 0.0001201116101583466\n",
|
||||
"layers.2.0.weight has gradient mean of 0.0007152041653171182\n",
|
||||
"layers.3.0.weight has gradient mean of 0.001398873864673078\n",
|
||||
"layers.4.0.weight has gradient mean of 0.005049646366387606\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class ExampleWithShortcut(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.fc1 = nn.Linear(3, 3)\n",
|
||||
" self.fc2 = nn.Linear(3, 3)\n",
|
||||
" self.fc3 = nn.Linear(3, 1)\n",
|
||||
" self.gelu = GELU()\n",
|
||||
"layer_sizes = [3, 3, 3, 3, 3, 1] \n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" shortcut = x\n",
|
||||
" x = self.gelu(self.fc1(x))\n",
|
||||
" x = self.gelu(self.fc2(x))\n",
|
||||
" print(\"Before shortcut:\", x)\n",
|
||||
" x = x + shortcut\n",
|
||||
" print(\"After shortcut:\", x)\n",
|
||||
" x = self.fc3(x)\n",
|
||||
" return x\n",
|
||||
"sample_input = torch.tensor([[1., 0., -1.]])\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"ex_short = ExampleWithShortcut()\n",
|
||||
"inputs = torch.tensor([[-1., 1., 2.]])\n",
|
||||
"print(\"Final network output:\", ex_short(inputs))"
|
||||
"model_without_shortcut = ExampleDeepNeuralNetwork(\n",
|
||||
" layer_sizes, use_shortcut=False\n",
|
||||
")\n",
|
||||
"print_gradients(model_without_shortcut, sample_input)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "837fd5d4-7345-4663-97f5-38f19dfde621",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Next, let's print the gradient values **with** shortcut connections:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"layers.0.0.weight has gradient mean of 0.22169792652130127\n",
|
||||
"layers.1.0.weight has gradient mean of 0.20694105327129364\n",
|
||||
"layers.2.0.weight has gradient mean of 0.32896995544433594\n",
|
||||
"layers.3.0.weight has gradient mean of 0.2665732502937317\n",
|
||||
"layers.4.0.weight has gradient mean of 1.3258541822433472\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"model_with_shortcut = ExampleDeepNeuralNetwork(\n",
|
||||
" layer_sizes, use_shortcut=True\n",
|
||||
")\n",
|
||||
"print_gradients(model_with_shortcut, sample_input)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -782,6 +863,7 @@
|
||||
"id": "79ff783a-46f0-49c5-a7a9-26a525764b6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- As we can see based on the output above, shortcut connections prevent the gradients from vanishing in the early layers (towards `layer.0`)\n",
|
||||
"- We will use this concept of a shortcut connection next when we implement a transformer block"
|
||||
]
|
||||
},
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 32 KiB |
Reference in New Issue
Block a user