# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# Synced 2026-02-21 11:18:01 +00:00
# 49 lines, 999 B, Python
import torch
|
||
|
|
from PIL import Image
|
||
|
|
import albumentations as A
|
||
|
|
from albumentations.pytorch import ToTensorV2
|
||
|
|
|
||
|
|
# --- Checkpointing ----------------------------------------------------------
# Flags and file names only; how they are consumed is decided by whichever
# module imports this config (not visible in this file).
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth"
CHECKPOINT_DISC = "disc.pth"

# --- Runtime ----------------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Training hyperparameters -----------------------------------------------
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10000
BATCH_SIZE = 16
LAMBDA_GP = 10  # presumably the gradient-penalty weight; confirm in the training loop
NUM_WORKERS = 4  # presumably the DataLoader worker count; confirm against the caller

# --- Image geometry ---------------------------------------------------------
HIGH_RES = 128  # side length of the square crop taken by both_transforms
LOW_RES = HIGH_RES // 4  # 32; side length produced by the Resize in lowres_transform
IMG_CHANNELS = 3
|
||
|
|
|
||
|
|
# Pipeline for the high-resolution image: rescale pixel values, then convert
# the array to a torch tensor. No resizing happens here.
highres_transform = A.Compose(
    [
        # With mean 0 and std 1, Normalize reduces to a division by its
        # max_pixel_value (255 by default per the Albumentations docs).
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
|
||
|
|
|
||
|
|
# Albumentations hands `interpolation` straight to OpenCV, so it must be an
# OpenCV flag. Pillow's Image.BICUBIC has the value 3, which OpenCV reads as
# INTER_AREA; bicubic in OpenCV is INTER_CUBIC, whose value is 2. The literal
# is used because cv2 is not imported by this module.
_INTER_CUBIC = 2  # numeric value of cv2.INTER_CUBIC

# Pipeline for the low-resolution image: bicubic downscale to
# LOW_RES x LOW_RES, rescale pixel values, then convert to a torch tensor.
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=_INTER_CUBIC),
        # With mean 0 and std 1, Normalize reduces to a division by its
        # max_pixel_value (255 by default per the Albumentations docs).
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
|
||
|
|
|
||
|
|
# Random geometric augmentations: a HIGH_RES x HIGH_RES crop followed by an
# optional horizontal flip and an optional 90-degree rotation (each applied
# with probability 0.5). There is no Normalize or tensor conversion here.
# NOTE(review): the name suggests this runs before highres_transform and
# lowres_transform so both see the same crop; confirm in the dataset code.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)
|
||
|
|
|
||
|
|
# Pipeline with no cropping, resizing or augmentation: rescale pixel values,
# then convert the array to a torch tensor.
test_transform = A.Compose(
    [
        # With mean 0 and std 1, Normalize reduces to a division by its
        # max_pixel_value (255 by default per the Albumentations docs).
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
|