# Source: Machine-Learning-Collection/ML/Kaggles/DiabeticRetinopathy/dataset.py
# Viewer metadata: 56 lines, 1.7 KiB, Python; last changed 2021-05-30 16:24:52 +02:00

import config
import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
class DRDataset(Dataset):
    """Image dataset that reads ``<name>.jpeg`` files from a single folder.

    In training mode the samples are the rows of the labels CSV; each row is
    unpacked as ``(image name, label)``.  In test mode the samples are the
    ``.jpeg`` files found in ``images_folder`` and every label is the
    placeholder ``-1``, so the same class can be reused to produce a test-set
    submission.

    Args:
        images_folder: Directory containing the ``.jpeg`` images.
        path_to_csv: CSV file loaded with ``pandas.read_csv``.  It is read in
            both modes, but only consulted when ``train`` is true.
        train: Select CSV-driven (True) or folder-driven (False) indexing.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result is indexed with ``"image"``.
    """

    # The only extension this dataset reads; names are stored without it.
    _EXTENSION = ".jpeg"

    def __init__(self, images_folder, path_to_csv, train=True, transform=None):
        super().__init__()
        self.data = pd.read_csv(path_to_csv)
        self.images_folder = images_folder
        # os.listdir returns entries in arbitrary order and may include files
        # that are not images.  Sorting makes test-mode indexing reproducible,
        # and filtering drops entries that __getitem__ could never open
        # (it always appends ".jpeg" to the name).
        self.image_files = sorted(
            entry
            for entry in os.listdir(images_folder)
            if entry.endswith(self._EXTENSION)
        )
        self.transform = transform
        self.train = train

    @classmethod
    def _image_stem(cls, image_file):
        """Return ``image_file`` without a single trailing ``.jpeg`` suffix."""
        # Only the trailing extension is removed; str.replace would also strip
        # ".jpeg" occurring in the middle of a name.
        if image_file.endswith(cls._EXTENSION):
            return image_file[: -len(cls._EXTENSION)]
        return image_file

    def __len__(self):
        """Return the number of CSV rows (train) or ``.jpeg`` files (test)."""
        if self.train:
            return self.data.shape[0]
        return len(self.image_files)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, label, image_file)`` where ``image`` is the
            decoded image as a NumPy array (or the transform's ``"image"``
            output), ``label`` comes from the CSV row or is ``-1`` in test
            mode, and ``image_file`` is the image name without extension.
        """
        if self.train:
            image_file, label = self.data.iloc[index]
        else:
            # Test images have no label; -1 is a placeholder so the same
            # dataset class can be reused for test-set submission.
            image_file, label = self.image_files[index], -1

        image_file = self._image_stem(image_file)
        image_path = os.path.join(self.images_folder, image_file + self._EXTENSION)

        # The context manager releases the file handle deterministically;
        # np.array copies the pixel data before the file is closed.
        with Image.open(image_path) as handle:
            image = np.array(handle)

        if self.transform:
            image = self.transform(image=image)["image"]

        return image, label, image_file
if __name__ == "__main__":
    # Smoke test: build the dataset and a loader, print the shapes of the
    # first batch, then terminate the process.
    import sys

    smoke_dataset = DRDataset(
        "../train/images_resized_650/",
        "../train/trainLabels.csv",
        transform=config.val_transforms,
    )
    smoke_loader = DataLoader(
        smoke_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    for images, labels, _names in tqdm(smoke_loader):
        print(images.shape)
        print(labels.shape)
        # One batch is enough to confirm the pipeline works.
        sys.exit()