# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# Synced 2026-02-21 11:18:01 +00:00 (Python, 28 lines, 702 B)
|
|
import torch
|
||
|
|
import pytorch_lightning as pl
|
||
|
|
from model import NN
|
||
|
|
from dataset import MnistDataModule
|
||
|
|
import config
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point: every tunable below is read from the project's
    # `config` module, except the epoch bounds which are fixed here.

    # Network under training.
    model = NN(
        input_size=config.INPUT_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_classes=config.NUM_CLASSES,
    )

    # Data module handed to every trainer stage below.
    data_module = MnistDataModule(
        data_dir=config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )

    # Lightning trainer: hardware and precision from config, 1 to 3 epochs.
    trainer = pl.Trainer(
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        min_epochs=1,
        max_epochs=3,
        precision=config.PRECISION,
    )

    # Run the stages in order (fit, validate, test) with the same model and
    # data module passed positionally to each.
    for run_stage in (trainer.fit, trainer.validate, trainer.test):
        run_stage(model, data_module)