mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 19:27:58 +00:00
55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
import torch
|
|
import torchvision.datasets as datasets
|
|
import os
|
|
from torch.utils.data import WeightedRandomSampler, DataLoader
|
|
import torchvision.transforms as transforms
|
|
import torch.nn as nn
|
|
|
|
# Methods for dealing with imbalanced datasets:
# 1. Oversampling
# 2. Class weighting
|
|
|
|
def get_loader(root_dir, batch_size):
    """Build a DataLoader that oversamples under-represented classes.

    Images are read with ``datasets.ImageFolder`` from ``root_dir``, resized
    to 224x224 and converted to tensors. Each sample is weighted by the
    inverse of its class's sample count and drawn, with replacement, by a
    ``WeightedRandomSampler``, so rare classes are picked more often.

    Args:
        root_dir: Dataset directory in the layout ``ImageFolder`` expects
            (one sub-directory per class).
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over the dataset that uses the weighted sampler;
        one pass draws ``len(dataset)`` samples.
    """
    my_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )
    dataset = datasets.ImageFolder(root=root_dir, transform=my_transforms)

    # Count samples per class from the labels ImageFolder already indexed.
    # Walking the tree with os.walk gives no guarantee that directories are
    # visited in the same order as ImageFolder's class indices, and it counts
    # every file on disk rather than only the samples in the dataset.
    labels = dataset.targets
    class_counts = [0] * len(dataset.classes)
    for label in labels:
        class_counts[label] += 1

    # Labels come from dataset.targets, so no image has to be loaded and
    # transformed just to look up its class. Every label seen here has a
    # count of at least one, so the division is safe.
    sample_weights = [1.0 / class_counts[label] for label in labels]

    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(sample_weights), replacement=True
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
|
|
|
|
|
|
def main():
    """Tally how often each class label is drawn from the balanced loader.

    Builds the loader for the "dataset" directory with a batch size of 8,
    iterates over it ten times, and prints the accumulated number of labels
    equal to 0 followed by the number equal to 1.
    """
    loader = get_loader(root_dir="dataset", batch_size=8)

    # Presumably class 0 is "retriever" and class 1 is "elkhound", going by
    # the original counter names -- confirm against the dataset folders.
    retriever_total = 0
    elkhound_total = 0
    for _ in range(10):
        for _, batch_labels in loader:
            retriever_total += (batch_labels == 0).sum()
            elkhound_total += (batch_labels == 1).sum()

    print(retriever_total)
    print(elkhound_total)


if __name__ == "__main__":
    main()
|
|
|