mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
update readmes, added pix2pix
This commit is contained in:
@@ -2,7 +2,9 @@ import torch
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
TRAIN_DIR = "data/train"
|
||||
VAL_DIR = "data/val"
|
||||
LEARNING_RATE = 2e-4
|
||||
BATCH_SIZE = 16
|
||||
NUM_WORKERS = 2
|
||||
@@ -22,6 +24,8 @@ both_transform = A.Compose(
|
||||
|
||||
transform_only_input = A.Compose(
|
||||
[
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.ColorJitter(p=0.2),
|
||||
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
|
||||
ToTensorV2(),
|
||||
]
|
||||
|
||||
@@ -71,7 +71,7 @@ def main():
|
||||
config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
|
||||
)
|
||||
|
||||
train_dataset = MapDataset(root_dir="data/maps/train/",)
|
||||
train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.BATCH_SIZE,
|
||||
@@ -80,7 +80,7 @@ def main():
|
||||
)
|
||||
g_scaler = torch.cuda.amp.GradScaler()
|
||||
d_scaler = torch.cuda.amp.GradScaler()
|
||||
val_dataset = MapDataset(root_dir="data/maps/val/")
|
||||
val_dataset = MapDataset(root_dir=config.VAL_DIR)
|
||||
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
|
||||
|
||||
for epoch in range(config.NUM_EPOCHS):
|
||||
|
||||
Reference in New Issue
Block a user