mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
add lightning code, finetuning whisper, recommender system neural collaborative filtering
This commit is contained in:
31
ML/Pytorch/pytorch_lightning/7. Callbacks/train.py
Normal file
31
ML/Pytorch/pytorch_lightning/7. Callbacks/train.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from model import NN
|
||||
from dataset import MnistDataModule
|
||||
import config
|
||||
from callbacks import MyPrintingCallback, EarlyStopping
|
||||
|
||||
torch.set_float32_matmul_precision("medium") # to make lightning happy
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = NN(
|
||||
input_size=config.INPUT_SIZE,
|
||||
learning_rate=config.LEARNING_RATE,
|
||||
num_classes=config.NUM_CLASSES,
|
||||
)
|
||||
dm = MnistDataModule(
|
||||
data_dir=config.DATA_DIR,
|
||||
batch_size=config.BATCH_SIZE,
|
||||
num_workers=config.NUM_WORKERS,
|
||||
)
|
||||
trainer = pl.Trainer(
|
||||
accelerator=config.ACCELERATOR,
|
||||
devices=config.DEVICES,
|
||||
min_epochs=1,
|
||||
max_epochs=config.NUM_EPOCHS,
|
||||
precision=config.PRECISION,
|
||||
callbacks=[MyPrintingCallback(), EarlyStopping(monitor="val_loss")],
|
||||
)
|
||||
trainer.fit(model, dm)
|
||||
trainer.validate(model, dm)
|
||||
trainer.test(model, dm)
|
||||
Reference in New Issue
Block a user