updated basic tutorials, better comments, code revision, checked it works with latest pytorch version

This commit is contained in:
Aladdin Persson
2022-12-19 23:39:48 +01:00
parent 3f53d68c4f
commit cd607c395c
14 changed files with 162 additions and 88 deletions

View File

@@ -3,6 +3,7 @@ import albumentations as A
import numpy as np
from utils import plot_examples
from PIL import Image
from tqdm import tqdm
image = Image.open("images/elon.jpeg")
@@ -14,18 +15,20 @@ transform = A.Compose(
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),
A.OneOf([
A.Blur(blur_limit=3, p=0.5),
A.ColorJitter(p=0.5),
], p=1.0),
A.OneOf(
[
A.Blur(blur_limit=3, p=0.5),
A.ColorJitter(p=0.5),
],
p=1.0,
),
]
)
images_list = [image]
image = np.array(image)
for i in range(15):
for i in tqdm(range(15)):
augmentations = transform(image=image)
augmented_img = augmentations["image"]
images_list.append(augmented_img)
plot_examples(images_list)

View File

@@ -8,6 +8,7 @@ from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import os
class ImageFolder(Dataset):
def __init__(self, root_dir, transform=None):
super(ImageFolder, self).__init__()
@@ -18,7 +19,7 @@ class ImageFolder(Dataset):
for index, name in enumerate(self.class_names):
files = os.listdir(os.path.join(root_dir, name))
self.data += list(zip(files, [index]*len(files)))
self.data += list(zip(files, [index] * len(files)))
def __len__(self):
return len(self.data)
@@ -43,10 +44,13 @@ transform = A.Compose(
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),
A.OneOf([
A.Blur(blur_limit=3, p=0.5),
A.ColorJitter(p=0.5),
], p=1.0),
A.OneOf(
[
A.Blur(blur_limit=3, p=0.5),
A.ColorJitter(p=0.5),
],
p=1.0,
),
A.Normalize(
mean=[0, 0, 0],
std=[1, 1, 1],
@@ -58,5 +62,5 @@ transform = A.Compose(
dataset = ImageFolder(root_dir="cat_dogs", transform=transform)
for x,y in dataset:
for x, y in dataset:
print(x.shape)

View File

@@ -8,7 +8,7 @@ import albumentations as A
def visualize(image):
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.axis("off")
plt.imshow(image)
plt.show()
@@ -22,7 +22,7 @@ def plot_examples(images, bboxes=None):
if bboxes is not None:
img = visualize_bbox(images[i - 1], bboxes[i - 1], class_name="Elon")
else:
img = images[i-1]
img = images[i - 1]
fig.add_subplot(rows, columns, i)
plt.imshow(img)
plt.show()