# File: Machine-Learning-Collection/ML/Pytorch/CNN_architectures/lenet5_pytorch.py (Python; listed as 64 lines, 1.6 KiB)

"""
An implementation of the LeNet CNN architecture.
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-05 Initial coding
* 2022-12-20 Updated comments, revised code, and verified it still works with the latest PyTorch version
"""
import torch
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
class LeNet(nn.Module):
    """LeNet-style convolutional network for 32x32 inputs.

    The network applies three 5x5 convolutions (6, 16 and 120 output
    channels, stride 1, no padding), with 2x2 average pooling after the
    first two, followed by two fully connected layers
    (120 -> 84 -> ``num_classes``). ReLU follows every layer except the
    last one, so the returned values are raw, unnormalized scores.

    Because no padding is used, a 32x32 input shrinks as
    32 -> 28 -> 14 -> 10 -> 5 -> 1, leaving exactly 120 features for the
    first linear layer. Other spatial sizes produce a different feature
    count and will fail in ``linear1``.

    Args:
        in_channels: Number of channels of the input images. Defaults to 1.
        num_classes: Number of scores produced per example. Defaults to 10.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        # Parameter-free layers, shared by every stage of the network.
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Layers are created in a fixed order (conv1, conv2, conv3, linear1,
        # linear2) so that seeded initialisation and state_dict key order
        # stay stable.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.conv3 = nn.Conv2d(
            in_channels=16,
            out_channels=120,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape (num_examples, in_channels, 32, 32).

        Returns:
            Tensor of shape (num_examples, num_classes).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        # num_examples x 120 x 1 x 1 --> num_examples x 120.
        # flatten derives the feature size from the actual dimensions, so it
        # avoids the ambiguous -1 that reshape(n, -1) hits on an empty batch.
        x = x.flatten(start_dim=1)
        x = self.relu(self.linear1(x))
        return self.linear2(x)
def test_lenet():
    """Run a random batch through a freshly built LeNet and return the result.

    The batch holds 64 single-channel 32x32 examples drawn from a standard
    normal distribution. It is created before the model so that random
    number consumption happens in the same order on every run.
    """
    batch = torch.randn(64, 1, 32, 32)
    network = LeNet()
    scores = network(batch)
    return scores
if __name__ == "__main__":
    # Smoke test when executed as a script: report the model's output shape.
    print(test_lenet().shape)