Files
Machine-Learning-Collection/ML/Projects/Exploring_MNIST/utils/mnist_data.py
Aladdin Persson 65b8c80495 Initial commit
2021-01-30 21:49:15 +01:00

95 lines
2.9 KiB
Python

import numpy as np
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, SubsetRandomSampler
class mnist_data(object):
    """Build MNIST ``DataLoader`` objects for training, validation and testing.

    Typical use is ``mnist_data(...).main()``, which fetches the dataset and
    returns the loaders. Note that ``main`` returns three loaders when
    ``create_validation_set`` is true and two otherwise.
    """

    def __init__(
        self,
        shuffle,
        transform_train,
        transform_test,
        num_workers=0,
        create_validation_set=True,
        batch_size=128,
        validation_size=0.2,
        random_seed=1,
    ):
        """Store the loader configuration; no data is touched here.

        Args:
            shuffle: Whether to shuffle. With a validation split this controls
                whether indices are permuted before splitting; without one it
                is forwarded to the training ``DataLoader``.
            transform_train: Transform passed to the training ``MNIST`` dataset.
            transform_test: Transform passed to the test ``MNIST`` dataset.
            num_workers: ``num_workers`` forwarded to every ``DataLoader``.
            create_validation_set: Whether ``main`` carves a validation split
                out of the training set.
            batch_size: ``batch_size`` forwarded to every ``DataLoader``.
            validation_size: Fraction of the training set used for validation;
                must lie in ``[0, 1]``.
            random_seed: Seed for the permutation used when splitting.
        """
        self.shuffle = shuffle
        self.validation_size = validation_size
        self.transform_train = transform_train
        self.transform_test = transform_test
        self.random_seed = random_seed
        self.create_validation_set = create_validation_set
        self.batch_size = batch_size
        self.num_workers = num_workers

    def download_data(self):
        """Create the MNIST train and test datasets rooted at ``./data``.

        Both datasets are constructed with ``download=True``.

        Returns:
            Tuple ``(mnist_trainset, mnist_testset)``.
        """
        mnist_trainset = datasets.MNIST(
            root="./data", train=True, download=True, transform=self.transform_train
        )
        mnist_testset = datasets.MNIST(
            root="./data", train=False, download=True, transform=self.transform_test
        )
        return mnist_trainset, mnist_testset

    def create_validationset(self, mnist_trainset):
        """Split the training set into training and validation loaders.

        The first ``int(validation_size * len(mnist_trainset))`` indices (after
        the optional seeded permutation) become the validation subset; the
        remainder is the training subset.

        Args:
            mnist_trainset: Dataset to split; only ``len()`` is used here, the
                object itself is handed to both ``DataLoader`` instances.

        Returns:
            Tuple ``(loader_train, loader_validation)``.

        Raises:
            ValueError: If ``validation_size`` is outside ``[0, 1]``.
        """
        if not 0 <= self.validation_size <= 1:
            # Out-of-range fractions would otherwise slice silently into an
            # empty or wrapped-around split.
            raise ValueError(
                "validation_size must be between 0 and 1, got "
                f"{self.validation_size!r}"
            )

        num_train = len(mnist_trainset)
        indices = list(range(num_train))
        split = int(self.validation_size * num_train)

        if self.shuffle:
            # A private RandomState yields the same permutation as seeding the
            # global NumPy RNG, without resetting random state for the rest of
            # the program.
            np.random.RandomState(self.random_seed).shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = SubsetRandomSampler(train_idx)
        validation_sampler = SubsetRandomSampler(valid_idx)

        loader_train = DataLoader(
            dataset=mnist_trainset,
            batch_size=self.batch_size,
            sampler=train_sampler,
            num_workers=self.num_workers,
        )
        # NOTE(review): the validation loader shares ``mnist_trainset``, so
        # validation samples pass through ``transform_train`` (not
        # ``transform_test``). Confirm this is intended if the training
        # transform contains augmentation.
        loader_validation = DataLoader(
            dataset=mnist_trainset,
            batch_size=self.batch_size,
            sampler=validation_sampler,
            num_workers=self.num_workers,
        )
        return loader_train, loader_validation

    def _make_test_loader(self, mnist_testset):
        """Return the unshuffled test ``DataLoader`` for ``mnist_testset``."""
        return DataLoader(
            dataset=mnist_testset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def main(self):
        """Fetch MNIST and build all loaders.

        Returns:
            ``(loader_train, loader_validation, loader_test)`` when
            ``create_validation_set`` is true, otherwise
            ``(loader_train, loader_test)``.
        """
        mnist_trainset, mnist_testset = self.download_data()

        if self.create_validation_set:
            loader_train, loader_validation = self.create_validationset(mnist_trainset)
            loader_test = self._make_test_loader(mnist_testset)
            return loader_train, loader_validation, loader_test

        loader_train = DataLoader(
            dataset=mnist_trainset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )
        loader_test = self._make_test_loader(mnist_testset)
        return loader_train, loader_test