revised some code examples

This commit is contained in:
Aladdin Persson
2021-03-24 21:57:40 +01:00
parent 74597aa8fd
commit e6c7f42c46
4 changed files with 90 additions and 117 deletions

View File

@@ -1,34 +1,33 @@
"""
Example code of a simple CNN network training on MNIST dataset.
The code is intended to show how to create a CNN network as well
as how to initialize loss, optimizer, etc. in a simple way to get
training to work with function that checks accuracy as well.
A simple walkthrough of how to code a convolutional neural network (CNN)
using the PyTorch library. For demonstration we train it on the very
common MNIST dataset of handwritten digits. In this code we go through
how to create the network as well as initialize a loss function, optimizer,
check accuracy and more.
Video explanation: https://youtu.be/wnK3uWv_WkU
Got any questions leave a comment on youtube :)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-08 Initial coding
Programmed by Aladdin Persson
* 2020-04-08: Initial coding
* 2021-03-24: More detailed comments and small revision of the code
"""
# Imports
import torch
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F # All functions that don't have any parameters
from torch.utils.data import (
DataLoader,
) # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
import torchvision # torch package for vision related things
import torch.nn.functional as F # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets # Standard datasets
import torchvision.transforms as transforms # Transformations we can perform on our dataset for augmentation
from torch import optim # For optimizers like SGD, Adam, etc.
from torch import nn # All neural network modules
from torch.utils.data import DataLoader # Gives easier dataset managment by creating mini batches etc.
from tqdm import tqdm # For nice progress bar!
# Simple CNN
class CNN(nn.Module):
def __init__(self, in_channels=1, num_classes=10):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=1,
in_channels=in_channels,
out_channels=8,
kernel_size=(3, 3),
stride=(1, 1),
@@ -51,7 +50,6 @@ class CNN(nn.Module):
x = self.pool(x)
x = x.reshape(x.shape[0], -1)
x = self.fc1(x)
return x
@@ -59,24 +57,20 @@ class CNN(nn.Module):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
in_channel = 1
in_channels = 1
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 5
num_epochs = 3
# Load Data
train_dataset = datasets.MNIST(
root="dataset/", train=True, transform=transforms.ToTensor(), download=True
)
train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(
root="dataset/", train=False, transform=transforms.ToTensor(), download=True
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
# Initialize network
model = CNN().to(device)
model = CNN(in_channels=in_channels, num_classes=num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
@@ -84,7 +78,7 @@ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Train Network
for epoch in range(num_epochs):
for batch_idx, (data, targets) in enumerate(train_loader):
for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
# Get data to cuda if possible
data = data.to(device=device)
targets = targets.to(device=device)
@@ -101,14 +95,7 @@ for epoch in range(num_epochs):
optimizer.step()
# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):
if loader.dataset.train:
print("Checking accuracy on training data")
else:
print("Checking accuracy on test data")
num_correct = 0
num_samples = 0
model.eval()
@@ -123,12 +110,10 @@ def check_accuracy(loader, model):
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
print(
f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}"
)
model.train()
return num_correct/num_samples
check_accuracy(train_loader, model)
check_accuracy(test_loader, model)
print(f"Accuracy on training set: {check_accuracy(train_loader, model)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")