# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# (synced 2026-02-21 11:18:01 +00:00; upstream file: 27 lines, 569 B, Python).
"""Configuration constants and the image preprocessing pipeline.

Everything here is evaluated at import time; other modules presumably read
these names as ``config.X`` (TODO: confirm against the callers).
"""

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4
BATCH_SIZE = 20
# Pinned (page-locked) host memory only speeds up host-to-GPU transfers, so
# enable it only when running on CUDA; on a CPU-only machine it has no benefit.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = True
SAVE_MODEL = True
# Checkpoint path; the "b7" name presumably refers to an EfficientNet-B7
# backbone -- confirm against the model definition.
CHECKPOINT_FILE = "b7.pth.tar"
WEIGHT_DECAY = 1e-4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1

# Resize every image to a fixed 448x448, normalize with the standard ImageNet
# per-channel mean/std (inputs assumed to be 0-255 pixel values, hence
# max_pixel_value=255.0), then convert the HWC numpy image to a CHW tensor.
basic_transform = A.Compose(
    [
        A.Resize(height=448, width=448),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)