mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
update readmes, added pix2pix
This commit is contained in:
@@ -2,7 +2,9 @@ import torch
|
|||||||
import albumentations as A
|
import albumentations as A
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
|
||||||
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
TRAIN_DIR = "data/train"
|
||||||
|
VAL_DIR = "data/val"
|
||||||
LEARNING_RATE = 2e-4
|
LEARNING_RATE = 2e-4
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
NUM_WORKERS = 2
|
NUM_WORKERS = 2
|
||||||
@@ -22,6 +24,8 @@ both_transform = A.Compose(
|
|||||||
|
|
||||||
transform_only_input = A.Compose(
|
transform_only_input = A.Compose(
|
||||||
[
|
[
|
||||||
|
A.HorizontalFlip(p=0.5),
|
||||||
|
A.ColorJitter(p=0.2),
|
||||||
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
|
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def main():
|
|||||||
config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
|
config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = MapDataset(root_dir="data/maps/train/",)
|
train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=config.BATCH_SIZE,
|
batch_size=config.BATCH_SIZE,
|
||||||
@@ -80,7 +80,7 @@ def main():
|
|||||||
)
|
)
|
||||||
g_scaler = torch.cuda.amp.GradScaler()
|
g_scaler = torch.cuda.amp.GradScaler()
|
||||||
d_scaler = torch.cuda.amp.GradScaler()
|
d_scaler = torch.cuda.amp.GradScaler()
|
||||||
val_dataset = MapDataset(root_dir="data/maps/val/")
|
val_dataset = MapDataset(root_dir=config.VAL_DIR)
|
||||||
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
|
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
|
||||||
|
|
||||||
for epoch in range(config.NUM_EPOCHS):
|
for epoch in range(config.NUM_EPOCHS):
|
||||||
|
|||||||
Reference in New Issue
Block a user