Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git (synced 2026-02-21 19:27:58 +00:00) — Python, 32 lines, 923 B.
# Entry point: train, validate and test the ``NN`` classifier on the
# ``MnistDataModule`` data with PyTorch Lightning. Hyperparameters and
# hardware settings are all read from the local ``config`` module.

import torch
import pytorch_lightning as pl

from model import NN
from dataset import MnistDataModule
import config
from callbacks import MyPrintingCallback, EarlyStopping

# Allow reduced-precision internal math for float32 matrix multiplies
# ("medium" level, see the torch docs). The original note says this keeps
# Lightning quiet — presumably its Tensor Core precision hint; confirm.
torch.set_float32_matmul_precision("medium")


if __name__ == "__main__":
    classifier = NN(
        input_size=config.INPUT_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_classes=config.NUM_CLASSES,
    )
    datamodule = MnistDataModule(
        data_dir=config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )

    # Early stopping watches "val_loss"; NOTE(review): this assumes NN logs a
    # metric under exactly that name during validation — confirm in model.py.
    trainer_callbacks = [MyPrintingCallback(), EarlyStopping(monitor="val_loss")]

    trainer = pl.Trainer(
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        min_epochs=1,  # force at least one full epoch before stopping is allowed
        max_epochs=config.NUM_EPOCHS,
        precision=config.PRECISION,
        callbacks=trainer_callbacks,
    )

    trainer.fit(classifier, datamodule)

    # Evaluate the in-memory weights left after fit() on the validation
    # and test splits, in that order.
    trainer.validate(classifier, datamodule)
    trainer.test(classifier, datamodule)