2021-01-30 21:49:15 +01:00
|
|
|
"""
An implementation of the LeNet CNN architecture.

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>

* 2020-04-05 Initial coding
* 2022-12-20 Update comments, code revision, checked still works with latest PyTorch version
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LeNet(nn.Module):
    """LeNet-style convolutional network.

    Three 5x5 convolutions (the first two each followed by 2x2 average
    pooling) and two fully connected layers, with ReLU activations.

    The layer sizes are chosen for 32x32 inputs:
    conv1 -> 28x28, pool -> 14x14, conv2 -> 10x10, pool -> 5x5, conv3 -> 1x1.
    conv3 then yields exactly 120 features per example for ``linear1``.
    Inputs whose spatial size does not reduce to 1x1 at conv3 produce a
    flattened size other than 120, and ``linear1`` will raise.

    Args:
        in_channels: Number of channels in the input images (default 1).
        num_classes: Number of output logits per example (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Shared, parameter-free modules reused after every conv/linear layer.
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # NOTE: layers are created in the same order as before, so default
        # weight initialisation under a fixed seed and state_dict keys are
        # unchanged.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.conv3 = nn.Conv2d(
            in_channels=16,
            out_channels=120,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Batch of images, shape (N, in_channels, H, W). H = W = 32 is the
               size the layers are designed for (see class docstring).

        Returns:
            Unnormalised logits of shape (N, num_classes).
        """
        # Shape comments below assume a 32x32 input.
        x = self.relu(self.conv1(x))  # N x 6 x 28 x 28
        x = self.pool(x)  # N x 6 x 14 x 14
        x = self.relu(self.conv2(x))  # N x 16 x 10 x 10
        x = self.pool(x)  # N x 16 x 5 x 5
        x = self.relu(self.conv3(x))  # N x 120 x 1 x 1
        x = torch.flatten(x, start_dim=1)  # N x 120 x 1 x 1 --> N x 120
        x = self.relu(self.linear1(x))  # N x 84
        return self.linear2(x)  # N x num_classes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_lenet():
    """Forward a random batch of 64 one-channel 32x32 images through a fresh LeNet.

    Returns:
        The model's output tensor for that batch.
    """
    # The input is drawn before the model is built; both steps consume the
    # global RNG, so this order is kept.
    batch = torch.randn(64, 1, 32, 32)
    net = LeNet()
    logits = net(batch)
    return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the smoke test and report the output shape.
    out = test_lenet()
    print(out.shape)
|