""" An implementation of LeNet CNN architecture. Programmed by Aladdin Persson * 2020-04-05 Initial coding * 2022-12-20 Update comments, code revision, checked still works with latest PyTorch version """ import torch import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.relu = nn.ReLU() self.pool = nn.AvgPool2d(kernel_size=2, stride=2) self.conv1 = nn.Conv2d( in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0, ) self.conv2 = nn.Conv2d( in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, ) self.conv3 = nn.Conv2d( in_channels=16, out_channels=120, kernel_size=5, stride=1, padding=0, ) self.linear1 = nn.Linear(120, 84) self.linear2 = nn.Linear(84, 10) def forward(self, x): x = self.relu(self.conv1(x)) x = self.pool(x) x = self.relu(self.conv2(x)) x = self.pool(x) x = self.relu( self.conv3(x) ) # num_examples x 120 x 1 x 1 --> num_examples x 120 x = x.reshape(x.shape[0], -1) x = self.relu(self.linear1(x)) x = self.linear2(x) return x def test_lenet(): x = torch.randn(64, 1, 32, 32) model = LeNet() return model(x) if __name__ == "__main__": out = test_lenet() print(out.shape)