update pretrain progress bar tutorial

This commit is contained in:
Aladdin Persson
2022-12-19 16:29:48 +01:00
parent 058742e581
commit 3f53d68c4f
2 changed files with 9 additions and 26 deletions

View File

@@ -3,11 +3,9 @@ Shows a small example of how to load a pretrain model (VGG16) from PyTorch,
and modifies this to train on the CIFAR10 dataset. The same method generalizes and modifies this to train on the CIFAR10 dataset. The same method generalizes
well to other datasets, but the modifications to the network may need to be changed. well to other datasets, but the modifications to the network may need to be changed.
Video explanation: https://youtu.be/U4bHxEhMGNk
Got any questions leave a comment on youtube :)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com> Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-08 Initial coding * 2020-04-08 Initial coding
* 2022-12-19 Updated comments, minor code changes, made sure it works with latest PyTorch
""" """
@@ -22,8 +20,8 @@ from torch.utils.data import (
) # Gives easier dataset managment and creates mini batches ) # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset import torchvision.transforms as transforms # Transformations we can perform on our dataset
from tqdm import tqdm
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters # Hyperparameters
@@ -32,17 +30,8 @@ learning_rate = 1e-3
batch_size = 1024 batch_size = 1024
num_epochs = 5 num_epochs = 5
# Simple Identity class that let's input pass without changes
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
# Load pretrain model & modify it # Load pretrain model & modify it
model = torchvision.models.vgg16(pretrained=True) model = torchvision.models.vgg16(weights="DEFAULT")
# If you want to do finetuning then set requires_grad = False # If you want to do finetuning then set requires_grad = False
# Remove these two lines if you want to train entire model, # Remove these two lines if you want to train entire model,
@@ -50,7 +39,7 @@ model = torchvision.models.vgg16(pretrained=True)
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
model.avgpool = Identity() model.avgpool = nn.Identity()
model.classifier = nn.Sequential( model.classifier = nn.Sequential(
nn.Linear(512, 100), nn.ReLU(), nn.Linear(100, num_classes) nn.Linear(512, 100), nn.ReLU(), nn.Linear(100, num_classes)
) )
@@ -71,7 +60,7 @@ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs): for epoch in range(num_epochs):
losses = [] losses = []
for batch_idx, (data, targets) in enumerate(train_loader): for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
# Get data to cuda if possible # Get data to cuda if possible
data = data.to(device=device) data = data.to(device=device)
targets = targets.to(device=device) targets = targets.to(device=device)

View File

@@ -3,9 +3,7 @@ import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader from torch.utils.data import TensorDataset, DataLoader
# Create a simple toy dataset example, normally this # Create a simple toy dataset
# would be doing custom class with __getitem__ etc,
# which we have done in custom dataset tutorials
x = torch.randn((1000, 3, 224, 224)) x = torch.randn((1000, 3, 224, 224))
y = torch.randint(low=0, high=10, size=(1000, 1)) y = torch.randint(low=0, high=10, size=(1000, 1))
ds = TensorDataset(x, y) ds = TensorDataset(x, y)
@@ -13,12 +11,12 @@ loader = DataLoader(ds, batch_size=8)
model = nn.Sequential( model = nn.Sequential(
nn.Conv2d(3, 10, kernel_size=3, padding=1, stride=1), nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, padding=1, stride=1),
nn.Flatten(), nn.Flatten(),
nn.Linear(10 * 224 * 224, 10), nn.Linear(10 * 224 * 224, 10),
) )
NUM_EPOCHS = 100 NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS): for epoch in range(NUM_EPOCHS):
loop = tqdm(loader) loop = tqdm(loader)
for idx, (x, y) in enumerate(loop): for idx, (x, y) in enumerate(loop):
@@ -35,7 +33,3 @@ for epoch in range(NUM_EPOCHS):
loop.set_postfix(loss=torch.rand(1).item(), acc=torch.rand(1).item()) loop.set_postfix(loss=torch.rand(1).item(), acc=torch.rand(1).item())
# There you go. Hope it was useful :) # There you go. Hope it was useful :)