mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 19:27:58 +00:00
44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import os
|
|
import numpy as np
|
|
import config
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
|
|
|
|
class MyImageFolder(Dataset):
|
|
def __init__(self, root_dir):
|
|
super(MyImageFolder, self).__init__()
|
|
self.data = []
|
|
self.root_dir = root_dir
|
|
self.class_names = os.listdir(root_dir)
|
|
|
|
for index, name in enumerate(self.class_names):
|
|
files = os.listdir(os.path.join(root_dir, name))
|
|
self.data += list(zip(files, [index] * len(files)))
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, index):
|
|
img_file, label = self.data[index]
|
|
root_and_dir = os.path.join(self.root_dir, self.class_names[label])
|
|
|
|
image = np.array(Image.open(os.path.join(root_and_dir, img_file)))
|
|
image = config.both_transforms(image=image)["image"]
|
|
high_res = config.highres_transform(image=image)["image"]
|
|
low_res = config.lowres_transform(image=image)["image"]
|
|
return low_res, high_res
|
|
|
|
|
|
def test():
|
|
dataset = MyImageFolder(root_dir="new_data/")
|
|
loader = DataLoader(dataset, batch_size=1, num_workers=8)
|
|
|
|
for low_res, high_res in loader:
|
|
print(low_res.shape)
|
|
print(high_res.shape)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test()
|