mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
35 lines
1022 B
Python
35 lines
1022 B
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import os
|
||
|
|
from torch.utils.data import Dataset
|
||
|
|
from PIL import Image
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
class MyImageFolder(Dataset):
    """Image dataset read from a ``root_dir/<class_name>/<image_file>`` tree.

    Every sub-directory of ``root_dir`` is treated as one class. The label of
    a sample is the index of its class directory in ``self.class_names``.

    Args:
        root_dir: Directory whose sub-directories each hold the files of one
            class.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``"image"`` entry (this matches
            the albumentations calling convention -- confirm against the
            caller). When ``None`` the raw array is returned.

    Attributes:
        data: List of ``(file_name, label)`` tuples, one per sample.
        class_names: Class directory names; position in the list == label.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.data = []
        self.root_dir = root_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order, so sort to keep the
        # class -> label mapping stable across machines and runs. Stray
        # non-directory entries (e.g. .DS_Store) are skipped because listing
        # them below would raise NotADirectoryError.
        self.class_names = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        for label, class_name in enumerate(self.class_names):
            # Sorted as well so that sample order is reproducible.
            file_names = sorted(os.listdir(os.path.join(root_dir, class_name)))
            self.data.extend((file_name, label) for file_name in file_names)

    def __len__(self):
        """Return the total number of samples over all classes."""
        return len(self.data)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(image, label)``.

        ``image`` is the decoded file as a NumPy array, or whatever the
        transform returns under its ``"image"`` key when a transform is set.
        """
        img_file, label = self.data[index]
        class_dir = os.path.join(self.root_dir, self.class_names[label])

        # Image.open is lazy and keeps the file handle open; np.array forces
        # the decode, and the with-block then releases the handle.
        with Image.open(os.path.join(class_dir, img_file)) as pil_image:
            image = np.array(pil_image)

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]

        return image, label
|
||
|
|
|