mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
add lightning code, finetuning whisper, recommender system neural collaborative filtering
This commit is contained in:
59
ML/Pytorch/pytorch_lightning/7. Callbacks/dataset.py
Normal file
59
ML/Pytorch/pytorch_lightning/7. Callbacks/dataset.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import random_split
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class MnistDataModule(pl.LightningDataModule):
|
||||
def __init__(self, data_dir, batch_size, num_workers):
|
||||
super().__init__()
|
||||
self.data_dir = data_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
||||
def prepare_data(self):
|
||||
datasets.MNIST(self.data_dir, train=True, download=True)
|
||||
datasets.MNIST(self.data_dir, train=False, download=True)
|
||||
|
||||
def setup(self, stage):
|
||||
entire_dataset = datasets.MNIST(
|
||||
root=self.data_dir,
|
||||
train=True,
|
||||
transform=transforms.ToTensor(),
|
||||
download=False,
|
||||
)
|
||||
self.train_ds, self.val_ds = random_split(entire_dataset, [50000, 10000])
|
||||
self.test_ds = datasets.MNIST(
|
||||
root=self.data_dir,
|
||||
train=False,
|
||||
transform=transforms.ToTensor(),
|
||||
download=False,
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.train_ds,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(
|
||||
self.val_ds,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(
|
||||
self.test_ds,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
Reference in New Issue
Block a user