# File: Machine-Learning-Collection/ML/Pytorch/pytorch_lightning/8. Logging Tensorboard/train.py
# (35 lines, 1.0 KiB, Python)

import torch
import pytorch_lightning as pl
from model import NN
from dataset import MnistDataModule
import config
from callbacks import MyPrintingCallback, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
def main() -> None:
    """Train, validate and test the MNIST classifier with TensorBoard logging.

    Builds the ``NN`` model, the ``MnistDataModule`` and a Lightning
    ``Trainer`` from the values in ``config``. It then runs ``fit``, followed
    by ``validate`` and ``test`` on the same data module. Metrics go to a
    ``TensorBoardLogger`` rooted at ``tb_logs`` under the experiment name
    ``mnist_model_v0``.
    """
    # Allow reduced-precision internal math for float32 matmuls where the GPU
    # supports it. Lightning warns about this setting on Tensor Core GPUs.
    torch.set_float32_matmul_precision("medium")

    logger = TensorBoardLogger("tb_logs", name="mnist_model_v0")
    model = NN(
        input_size=config.INPUT_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_classes=config.NUM_CLASSES,
    )
    dm = MnistDataModule(
        data_dir=config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )
    trainer = pl.Trainer(
        logger=logger,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        min_epochs=1,
        max_epochs=config.NUM_EPOCHS,
        precision=config.PRECISION,
        # Early stopping watches "val_loss"; the model must log that key.
        callbacks=[MyPrintingCallback(), EarlyStopping(monitor="val_loss")],
    )
    trainer.fit(model, dm)
    # NOTE(review): validate/test use the in-memory weights from the last
    # epoch, which may not be the lowest-val_loss epoch when early stopping
    # triggers. Pass ckpt_path="best" if the best checkpoint is intended.
    trainer.validate(model, dm)
    trainer.test(model, dm)


if __name__ == "__main__":
    main()