mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import torch
|
|
import pytorch_lightning as pl
|
|
from model import NN
|
|
from dataset import MnistDataModule
|
|
import config
|
|
from callbacks import MyPrintingCallback, EarlyStopping
|
|
from pytorch_lightning.loggers import TensorBoardLogger
|
|
from pytorch_lightning.profilers import PyTorchProfiler
|
|
|
|
torch.set_float32_matmul_precision("medium") # to make lightning happy
|
|
|
|
if __name__ == "__main__":
|
|
logger = TensorBoardLogger("tb_logs", name="mnist_model_v1")
|
|
profiler = PyTorchProfiler(
|
|
on_trace_ready=torch.profiler.tensorboard_trace_handler("tb_logs/profiler0"),
|
|
schedule=torch.profiler.schedule(skip_first=10, wait=1, warmup=1, active=20),
|
|
)
|
|
model = NN(
|
|
input_size=config.INPUT_SIZE,
|
|
learning_rate=config.LEARNING_RATE,
|
|
num_classes=config.NUM_CLASSES,
|
|
)
|
|
dm = MnistDataModule(
|
|
data_dir=config.DATA_DIR,
|
|
batch_size=config.BATCH_SIZE,
|
|
num_workers=config.NUM_WORKERS,
|
|
)
|
|
trainer = pl.Trainer(
|
|
profiler=profiler,
|
|
logger=logger,
|
|
accelerator=config.ACCELERATOR,
|
|
devices=config.DEVICES,
|
|
min_epochs=1,
|
|
max_epochs=config.NUM_EPOCHS,
|
|
precision=config.PRECISION,
|
|
callbacks=[MyPrintingCallback(), EarlyStopping(monitor="val_loss")],
|
|
)
|
|
trainer.fit(model, dm)
|
|
trainer.validate(model, dm)
|
|
trainer.test(model, dm)
|