"""
|
|
Simple pytorch lightning example
|
|
"""
|
|
|
|
# Imports
import torch
import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # All neural network modules
from torch.utils.data import DataLoader  # Gives easier dataset management by creating mini batches etc.
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks import Callback, EarlyStopping
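
# NOTE: this example targets pytorch_lightning < 2.0; trainer.tune(),
# auto_lr_find, and the training_epoch_end hook used below were removed
# in Lightning 2.0.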

# Trade some float32 matmul precision for speed on GPUs that support it
torch.set_float32_matmul_precision("medium")

criterion = nn.CrossEntropyLoss()

# Alternative: use 20% of the training data for validation
# train_set_size = int(len(train_dataset) * 0.8)
# valid_set_size = len(train_dataset) - train_set_size
#
# # Split the train set into two
# seed = torch.Generator().manual_seed(42)
# train_dataset, val_dataset = torch.utils.data.random_split(
#     train_dataset, [train_set_size, valid_set_size], generator=seed
# )


class CNNLightning(pl.LightningModule):
    def __init__(self, lr=3e-4, in_channels=1, num_classes=10):
        super().__init__()
        self.lr = lr
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Two 2x2 max-pools reduce 28x28 MNIST images to 7x7 feature maps
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)
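
        # Optionally, calling self.save_hyperparameters() here would record
        # lr, in_channels, and num_classes in checkpoints and the logger.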

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = criterion(y_hat, y)
        # Update the metric state, then log the metric object itself so
        # Lightning handles accumulation across steps and devices
        self.train_acc(y_hat, y)
        self.log(
            "train_acc_step",
            self.train_acc,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )
        return loss

    def training_epoch_end(self, outputs):
        self.train_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.test_acc(y_hat, y)
        self.log("test_loss", loss, on_step=True)
        self.log("test_acc", accuracy, on_step=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.val_acc(y_hat, y)
        # Log at epoch level only, so the "val_loss" key exists for the
        # EarlyStopping callback to monitor
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_acc", accuracy, on_epoch=True)

    def predict_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        return y_hat

    def _common_step(self, x, batch_idx):
        # Shared forward pass: two conv -> relu -> pool stages, then flatten
        # into the linear classifier head
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)
        y_hat = self.fc1(x)
        return y_hat

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
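
    # A minimal sketch of also returning a learning-rate scheduler from this
    # hook (optional; StepLR is just an illustrative choice):
    #
    #     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    #     return {"optimizer": optimizer, "lr_scheduler": scheduler}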


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=512):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage):
        mnist_full = datasets.MNIST(
            root="dataset/", train=True, transform=transforms.ToTensor(), download=True
        )
        self.mnist_test = datasets.MNIST(
            root="dataset/", train=False, transform=transforms.ToTensor(), download=True
        )
        # MNIST has 60,000 training images; hold out 5,000 for validation
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000]
        )
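
    # Note: setup() runs on every process in distributed training, so for
    # multi-node jobs the download=True calls above are usually moved into
    # the prepare_data() hook, which Lightning runs only once.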

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=6,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=2, shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=2, shuffle=False
        )
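
    # A hypothetical predict_dataloader (not defined in this example) would
    # let trainer.predict() consume this DataModule directly:
    #
    #     def predict_dataloader(self):
    #         return DataLoader(self.mnist_test, batch_size=self.batch_size)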


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


# Note: Lightning moves the model and batches to the right device(s) itself,
# so no manual torch.device handling is needed.

if __name__ == "__main__":
    # Initialize network
    model_lightning = CNNLightning()

    trainer = pl.Trainer(
        # fast_dev_run=True,
        # overfit_batches=3,
        max_epochs=5,
        precision=16,
        accelerator="gpu",
        devices=[0, 1],
        callbacks=[
            MyPrintingCallback(),
            EarlyStopping(monitor="val_loss", mode="min"),
        ],
        auto_lr_find=True,
        enable_model_summary=True,
        profiler="simple",
        strategy="deepspeed_stage_1",
        # accumulate_grad_batches=2,
        # auto_scale_batch_size="binsearch",
        # log_every_n_steps=1,
    )
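
    # Notes on the flags above: precision=16 enables mixed-precision training,
    # and strategy="deepspeed_stage_1" shards optimizer states across the
    # listed GPUs via DeepSpeed ZeRO stage 1 (requires the deepspeed package).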

    dm = MNISTDataModule()

    # Run trainer.tune() first so auto_lr_find can pick a learning rate
    # (batch-size scaling would also run here if auto_scale_batch_size
    # were enabled)
    trainer.tune(model_lightning, dm)

    trainer.fit(
        model=model_lightning,
        datamodule=dm,
    )

    # Test the model on the test loader from the LightningDataModule
    trainer.test(model=model_lightning, datamodule=dm)
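
    # A minimal sketch of running inference with the predict_step defined
    # above (assumes a predict_dataloader is added to the DataModule):
    # predictions = trainer.predict(model=model_lightning, datamodule=dm)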