diff --git a/ML/Pytorch/GANs/Pix2Pix/config.py b/ML/Pytorch/GANs/Pix2Pix/config.py index e5c5e3f..870a3d0 100644 --- a/ML/Pytorch/GANs/Pix2Pix/config.py +++ b/ML/Pytorch/GANs/Pix2Pix/config.py @@ -2,7 +2,9 @@ import torch import albumentations as A from albumentations.pytorch import ToTensorV2 -DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +TRAIN_DIR = "data/train" +VAL_DIR = "data/val" LEARNING_RATE = 2e-4 BATCH_SIZE = 16 NUM_WORKERS = 2 @@ -22,6 +24,8 @@ both_transform = A.Compose( transform_only_input = A.Compose( [ + A.HorizontalFlip(p=0.5), + A.ColorJitter(p=0.2), A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,), ToTensorV2(), ] diff --git a/ML/Pytorch/GANs/Pix2Pix/train.py b/ML/Pytorch/GANs/Pix2Pix/train.py index 45017b0..b4362a1 100644 --- a/ML/Pytorch/GANs/Pix2Pix/train.py +++ b/ML/Pytorch/GANs/Pix2Pix/train.py @@ -71,7 +71,7 @@ def main(): config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE, ) - train_dataset = MapDataset(root_dir="data/maps/train/",) + train_dataset = MapDataset(root_dir=config.TRAIN_DIR) train_loader = DataLoader( train_dataset, batch_size=config.BATCH_SIZE, @@ -80,7 +80,7 @@ def main(): ) g_scaler = torch.cuda.amp.GradScaler() d_scaler = torch.cuda.amp.GradScaler() - val_dataset = MapDataset(root_dir="data/maps/val/") + val_dataset = MapDataset(root_dir=config.VAL_DIR) val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) for epoch in range(config.NUM_EPOCHS):