Files
Machine-Learning-Collection/ML/Pytorch/pytorch_lightning/6. Restructuring/train.py

28 lines
702 B
Python
Raw Normal View History

import torch
import pytorch_lightning as pl
from model import NN
from dataset import MnistDataModule
import config
def main():
    """Run the full train / validate / test cycle for the ``NN`` model.

    The model, the data module, and the trainer are each built from values
    read off the ``config`` module; the trainer is then asked to fit,
    validate, and test the model, in that order, against the same data
    module.
    """
    # Model hyperparameters come straight from config.
    network = NN(
        input_size=config.INPUT_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_classes=config.NUM_CLASSES,
    )

    # Data loading settings (location, batching, worker count) from config.
    data_module = MnistDataModule(
        data_dir=config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )

    # Hardware and precision are configurable; the epoch bounds are fixed
    # here (1 to 3) rather than read from config.
    runner = pl.Trainer(
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        min_epochs=1,
        max_epochs=3,
        precision=config.PRECISION,
    )

    # Order matters: train first, then evaluate the resulting model.
    runner.fit(network, data_module)
    runner.validate(network, data_module)
    runner.test(network, data_module)


if __name__ == "__main__":
    main()