make a new example for shortcut connections

This commit is contained in:
rasbt
2024-02-15 19:34:12 -06:00
parent 250e6306e2
commit 557ddfc684
2 changed files with 105 additions and 23 deletions

View File

@@ -738,43 +738,124 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "05473938-799c-49fd-86d4-8ed65f94fee6",
"metadata": {},
"outputs": [],
"source": [
"class ExampleDeepNeuralNetwork(nn.Module):\n",
" def __init__(self, layer_sizes, use_shortcut):\n",
" super().__init__()\n",
" self.use_shortcut = use_shortcut\n",
" self.layers = nn.ModuleList([\n",
" nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),\n",
" nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),\n",
" nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),\n",
" nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),\n",
" nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())\n",
" ])\n",
"\n",
" def forward(self, x):\n",
" for layer in self.layers:\n",
" # Compute the output of the current layer\n",
" layer_output = layer(x)\n",
" # Check if shortcut can be applied\n",
" if self.use_shortcut and x.size() == layer_output.size():\n",
" x = x + layer_output\n",
" else:\n",
" x = layer_output\n",
" return x\n",
"\n",
"\n",
"def print_gradients(model, x):\n",
" # Forward pass\n",
" output = model(x)\n",
" target = torch.tensor([[0.]])\n",
"\n",
" # Calculate loss based on how close the target\n",
" # and output are\n",
" loss = nn.MSELoss()\n",
" loss = loss(output, target)\n",
" \n",
" # Backward pass to calculate the gradients\n",
" loss.backward()\n",
"\n",
" for name, param in model.named_parameters():\n",
" if 'weight' in name:\n",
" # Print the mean absolute gradient of the weights\n",
" print(f\"{name} has gradient mean of {param.grad.abs().mean().item()}\")"
]
},
{
"cell_type": "markdown",
"id": "b39bf277-b3db-4bb1-84ce-7a20caff1011",
"metadata": {},
"source": [
"- Let's print the gradient values first **without** shortcut connections:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c75f43cc-6923-4018-b980-26023086572c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before shortcut: tensor([[0.0950, 0.0634, 0.3361]], grad_fn=<MulBackward0>)\n",
"After shortcut: tensor([[-0.9050, 1.0634, 2.3361]], grad_fn=<AddBackward0>)\n",
"Final network output: tensor([[0.2427]], grad_fn=<AddmmBackward0>)\n"
"layers.0.0.weight has gradient mean of 0.00020173587836325169\n",
"layers.1.0.weight has gradient mean of 0.0001201116101583466\n",
"layers.2.0.weight has gradient mean of 0.0007152041653171182\n",
"layers.3.0.weight has gradient mean of 0.001398873864673078\n",
"layers.4.0.weight has gradient mean of 0.005049646366387606\n"
]
}
],
"source": [
"class ExampleWithShortcut(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(3, 3)\n",
" self.fc2 = nn.Linear(3, 3)\n",
" self.fc3 = nn.Linear(3, 1)\n",
" self.gelu = GELU()\n",
"layer_sizes = [3, 3, 3, 3, 3, 1] \n",
"\n",
" def forward(self, x):\n",
" shortcut = x\n",
" x = self.gelu(self.fc1(x))\n",
" x = self.gelu(self.fc2(x))\n",
" print(\"Before shortcut:\", x)\n",
" x = x + shortcut\n",
" print(\"After shortcut:\", x)\n",
" x = self.fc3(x)\n",
" return x\n",
"sample_input = torch.tensor([[1., 0., -1.]])\n",
"\n",
"torch.manual_seed(123)\n",
"ex_short = ExampleWithShortcut()\n",
"inputs = torch.tensor([[-1., 1., 2.]])\n",
"print(\"Final network output:\", ex_short(inputs))"
"model_without_shortcut = ExampleDeepNeuralNetwork(\n",
" layer_sizes, use_shortcut=False\n",
")\n",
"print_gradients(model_without_shortcut, sample_input)"
]
},
{
"cell_type": "markdown",
"id": "837fd5d4-7345-4663-97f5-38f19dfde621",
"metadata": {},
"source": [
"- Next, let's print the gradient values **with** shortcut connections:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"layers.0.0.weight has gradient mean of 0.22169792652130127\n",
"layers.1.0.weight has gradient mean of 0.20694105327129364\n",
"layers.2.0.weight has gradient mean of 0.32896995544433594\n",
"layers.3.0.weight has gradient mean of 0.2665732502937317\n",
"layers.4.0.weight has gradient mean of 1.3258541822433472\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"model_with_shortcut = ExampleDeepNeuralNetwork(\n",
" layer_sizes, use_shortcut=True\n",
")\n",
"print_gradients(model_with_shortcut, sample_input)"
]
},
{
@@ -782,6 +863,7 @@
"id": "79ff783a-46f0-49c5-a7a9-26a525764b6e",
"metadata": {},
"source": [
"- As we can see based on the output above, shortcut connections prevent the gradients from vanishing in the early layers (towards `layers.0`)\n",
"- We will use this concept of a shortcut connection next when we implement a transformer block"
]
},

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 32 KiB