"""Train and evaluate ``NN`` on the data supplied by ``MnistDataModule``.

Model hyperparameters, data-loading options and hardware settings are read
from the project-local ``config`` module. Run this file directly to fit the
model, then evaluate it with ``Trainer.validate`` and ``Trainer.test``.
"""

import torch  # noqa: F401  (kept from the original script; not referenced below)
import pytorch_lightning as pl

from model import NN
from dataset import MnistDataModule
import config


def main() -> None:
    """Build the model, data module and trainer, then fit, validate and test.

    Optional config overrides:
        MIN_EPOCHS: minimum number of training epochs (default 1).
        MAX_EPOCHS: maximum number of training epochs (default 3).
    """
    model = NN(
        input_size=config.INPUT_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_classes=config.NUM_CLASSES,
    )
    dm = MnistDataModule(
        data_dir=config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )
    trainer = pl.Trainer(
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        # The epoch bounds used to be hard-coded while every other setting came
        # from config. They can now be overridden there, and the previous
        # values remain the defaults.
        min_epochs=getattr(config, "MIN_EPOCHS", 1),
        max_epochs=getattr(config, "MAX_EPOCHS", 3),
        precision=config.PRECISION,
    )

    trainer.fit(model, dm)
    # Evaluate the same in-memory model object that fit() just trained.
    trainer.validate(model, dm)
    trainer.test(model, dm)


if __name__ == "__main__":
    main()