""" A simple walkthrough of how to code a fully connected neural network using the PyTorch library. For demonstration we train it on the very common MNIST dataset of handwritten digits. In this code we go through how to create the network as well as initialize a loss function, optimizer, check accuracy and more. Programmed by Aladdin Persson * 2020-04-08: Initial coding * 2021-03-24: Added more detailed comments also removed part of check_accuracy which would only work specifically on MNIST. """ # Imports import torch import torchvision # torch package for vision related things import torch.nn.functional as F # Parameterless functions, like (some) activation functions import torchvision.datasets as datasets # Standard datasets import torchvision.transforms as transforms # Transformations we can perform on our dataset for augmentation from torch import optim # For optimizers like SGD, Adam, etc. from torch import nn # All neural network modules from torch.utils.data import DataLoader # Gives easier dataset managment by creating mini batches etc. from tqdm import tqdm # For nice progress bar! # Here we create our simple neural network. For more details here we are subclassing and # inheriting from nn.Module, this is the most general way to create your networks and # allows for more flexibility. I encourage you to also check out nn.Sequential which # would be easier to use in this scenario but I wanted to show you something that # "always" works. class NN(nn.Module): def __init__(self, input_size, num_classes): super(NN, self).__init__() # Our first linear layer take input_size, in this case 784 nodes to 50 # and our second linear layer takes 50 to the num_classes we have, in # this case 10. self.fc1 = nn.Linear(input_size, 50) self.fc2 = nn.Linear(50, num_classes) def forward(self, x): """ x here is the mnist images and we run it through fc1, fc2 that we created above. we also add a ReLU activation function in between and for that (since it has no parameters) I recommend using nn.functional (F) """ x = F.relu(self.fc1(x)) x = self.fc2(x) return x # Set device cuda for GPU if it's available otherwise run on the CPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Hyperparameters of our neural network which depends on the dataset, and # also just experimenting to see what works well (learning rate for example). 

# Set device to cuda for GPU if it's available, otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters of our neural network, which depend on the dataset and
# also just on experimenting to see what works well (learning rate for example).
input_size = 784
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 3

# Load Training and Test data
train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Initialize network
model = NN(input_size=input_size, num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # Get to correct shape: flatten each 1x28x28 image into a 784-dim vector
        data = data.reshape(data.shape[0], -1)

        # Forward pass
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent or Adam step
        optimizer.step()


# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x = x.reshape(x.shape[0], -1)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    model.train()
    return num_correct / num_samples


print(f"Accuracy on training set: {check_accuracy(train_loader, model)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")
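

# As a quick sanity check beyond the aggregate accuracy above, here is a minimal
# sketch (not part of the original walkthrough) of running the trained model on a
# single test image; the names image, label, and prediction are illustrative. It
# reuses the same flatten-then-argmax flow as check_accuracy, just for one sample.
model.eval()
with torch.no_grad():
    image, label = test_dataset[0]                  # one 1x28x28 image tensor and its integer digit label
    image = image.to(device=device).reshape(1, -1)  # flatten to shape (1, 784), keeping a batch dimension
    prediction = model(image).argmax(dim=1).item()  # index of the highest logit is the predicted digit
print(f"Predicted digit: {prediction}, true digit: {label}")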