mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
import torchvision.datasets as datasets
|
||
|
|
import torchvision.transforms as transforms
|
||
|
|
from torch import nn, optim
|
||
|
|
from torch.utils.data import DataLoader
|
||
|
|
from torch.utils.data import random_split
|
||
|
|
import pytorch_lightning as pl
|
||
|
|
|
||
|
|
|
||
|
|
class MnistDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST train/val/test loaders.

    The 60,000-sample MNIST training set is randomly divided into 50,000
    training and 10,000 validation samples; the MNIST test set is used as-is.
    Images are converted with ``transforms.ToTensor()``.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker count passed to every dataloader.
    """

    def __init__(self, data_dir, batch_size, num_workers):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download the MNIST training and test splits into ``data_dir``."""
        # Training split first, then the test split; the dataset objects are
        # discarded because only the download side effect is wanted here.
        for is_train in (True, False):
            datasets.MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage):
        """Create ``train_ds``, ``val_ds`` and ``test_ds`` from files on disk.

        Args:
            stage: Accepted for the hook signature; not used by this method.
        """
        # download=False: the files are expected to already be in data_dir.
        full_train = datasets.MNIST(
            self.data_dir,
            train=True,
            transform=transforms.ToTensor(),
            download=False,
        )
        train_part, val_part = random_split(full_train, [50000, 10000])
        self.train_ds = train_part
        self.val_ds = val_part

        self.test_ds = datasets.MNIST(
            self.data_dir,
            train=False,
            transform=transforms.ToTensor(),
            download=False,
        )

    def _mnist_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training subset."""
        return self._mnist_loader(self.train_ds, True)

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over the validation subset."""
        return self._mnist_loader(self.val_ds, False)

    def test_dataloader(self):
        """Return a non-shuffling DataLoader over the MNIST test set."""
        return self._mnist_loader(self.test_ds, False)
|