update readmes, added pix2pix

This commit is contained in:
Aladdin Persson
2021-03-06 12:52:33 +01:00
parent 65a51c6e64
commit 946465e63c
2 changed files with 7 additions and 3 deletions

View File

@@ -2,7 +2,9 @@ import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
@@ -22,6 +24,8 @@ both_transform = A.Compose(
transform_only_input = A.Compose(
[
A.HorizontalFlip(p=0.5),
A.ColorJitter(p=0.2),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
ToTensorV2(),
]

View File

@@ -71,7 +71,7 @@ def main():
config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
)
train_dataset = MapDataset(root_dir="data/maps/train/",)
train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
train_loader = DataLoader(
train_dataset,
batch_size=config.BATCH_SIZE,
@@ -80,7 +80,7 @@ def main():
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
val_dataset = MapDataset(root_dir="data/maps/val/")
val_dataset = MapDataset(root_dir=config.VAL_DIR)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
for epoch in range(config.NUM_EPOCHS):