Files

32 lines
863 B
Python

import torch
import torchvision.datasets as datasets # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from model import WhisperFinetuning
from dataset import WhisperDataset
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DeepSpeedStrategy
# "medium" allows float32 matrix multiplications to use faster, lower-precision
# internal math (bfloat16 where a fast kernel exists, otherwise the "high"
# setting), trading a little accuracy for speed.
torch.set_float32_matmul_precision("medium")

# TODO: things to add -- TensorBoardLogger and DeepSpeedStrategy are imported
# above but not yet passed to the Trainer.

# Training hyperparameters.
lr = 1e-5
batch_size = 32
num_workers = 4


def main():
    """Fine-tune ``WhisperFinetuning`` on ``data/`` with a single-GPU Lightning Trainer."""
    # Build the model and datamodule only when run as a script. DataLoader
    # workers started with the "spawn" method (default on Windows/macOS)
    # re-import the main module, so constructing these at import time would
    # repeat the (potentially expensive) work in every worker process.
    model = WhisperFinetuning(lr)
    dm = WhisperDataset(data_dir="data/", batch_size=batch_size, num_workers=num_workers)
    trainer = pl.Trainer(
        max_epochs=1000,
        accelerator="gpu",
        devices=[0],  # train on the first GPU only
        # 16-bit mixed precision; NOTE(review): Lightning 2.x spells this
        # "16-mixed" and treats the integer 16 as a deprecated alias -- confirm
        # against the installed Lightning version.
        precision=16,
    )
    trainer.fit(model, dm)


if __name__ == "__main__":
    main()