mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
add imbalanced classes video code and kaggle cat vs dog
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyImageFolder(Dataset):
|
||||
def __init__(self, root_dir, transform=None):
|
||||
super(MyImageFolder, self).__init__()
|
||||
self.data = []
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.class_names = os.listdir(root_dir)
|
||||
|
||||
for index, name in enumerate(self.class_names):
|
||||
files = os.listdir(os.path.join(root_dir, name))
|
||||
self.data += list(zip(files, [index]*len(files)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_file, label = self.data[index]
|
||||
root_and_dir = os.path.join(self.root_dir, self.class_names[label])
|
||||
image = np.array(Image.open(os.path.join(root_and_dir, img_file)))
|
||||
|
||||
if self.transform is not None:
|
||||
augmentations = self.transform(image=image)
|
||||
image = augmentations["image"]
|
||||
|
||||
return image, label
|
||||
|
||||
Reference in New Issue
Block a user