# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# (synced 2026-02-20 13:50:41 +00:00; 56 lines, 1.7 KiB, Python)
import config
|
|
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
|
|
class DRDataset(Dataset):
    """Dataset of ".jpeg" images with labels read from a CSV file.

    In train mode every item comes from one row of the CSV, which is unpacked
    as ``(image name, label)`` and therefore must have exactly two columns.
    Otherwise the items are the ".jpeg" files found in ``images_folder`` and
    the label is always -1, so the same class can be reused for an unlabeled
    test set.

    Args:
        images_folder: Directory that holds the ``<name>.jpeg`` images.
        path_to_csv: Path of the CSV file with one (image name, label) row
            per sample. It is read in both modes.
        train: If True, index through the CSV rows; otherwise index through
            the directory listing.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result is indexed with ``"image"``.
    """

    _IMAGE_SUFFIX = ".jpeg"

    def __init__(self, images_folder, path_to_csv, train=True, transform=None):
        super().__init__()
        self.data = pd.read_csv(path_to_csv)
        self.images_folder = images_folder
        # os.listdir returns entries in arbitrary order and may include files
        # that are not images; keep only the images and sort them so that an
        # index always refers to the same file.
        self.image_files = sorted(
            name
            for name in os.listdir(images_folder)
            if name.endswith(self._IMAGE_SUFFIX)
        )
        self.transform = transform
        self.train = train

    def __len__(self):
        """Return the number of CSV rows (train) or of listed images (test)."""
        if self.train:
            return self.data.shape[0]
        return len(self.image_files)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, label, image_file)`` where ``image`` is the
            (optionally transformed) image array, ``label`` is the CSV label
            or -1 in test mode, and ``image_file`` is the image name without
            the ".jpeg" suffix.
        """
        if self.train:
            image_file, label = self.data.iloc[index]
        else:
            # No labels exist for the test set; -1 is a placeholder so the
            # same dataset class can be used when creating a submission.
            image_file, label = self.image_files[index], -1

        # Strip only a trailing extension; a plain str.replace would also
        # remove ".jpeg" occurring in the middle of a name.
        if image_file.endswith(self._IMAGE_SUFFIX):
            image_file = image_file[: -len(self._IMAGE_SUFFIX)]

        image_path = os.path.join(
            self.images_folder, image_file + self._IMAGE_SUFFIX
        )
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)

        if self.transform:
            image = self.transform(image=image)["image"]

        return image, label, image_file
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the training dataset, pull a single batch through a
    # DataLoader and print the shapes of the image and label batches.
    dataset = DRDataset(
        images_folder="../train/images_resized_650/",
        path_to_csv="../train/trainLabels.csv",
        transform=config.val_transforms,
    )
    loader = DataLoader(
        dataset=dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True
    )

    for x, label, _ in tqdm(loader):
        print(x.shape)
        print(label.shape)
        # One batch is enough to confirm the pipeline works; the script ends
        # right after the loop.
        break