diff --git a/ML/Pytorch/CNN_architectures/pytorch_inceptionet.py b/ML/Pytorch/CNN_architectures/pytorch_inceptionet.py index c70cde7..38faef9 100644 --- a/ML/Pytorch/CNN_architectures/pytorch_inceptionet.py +++ b/ML/Pytorch/CNN_architectures/pytorch_inceptionet.py @@ -51,7 +51,7 @@ class GoogLeNet(nn.Module): self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1) self.dropout = nn.Dropout(p=0.4) - self.fc1 = nn.Linear(1024, 1000) + self.fc1 = nn.Linear(1024, num_classes) if self.aux_logits: self.aux1 = InceptionAux(512, num_classes)