import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)
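
# UNET (model.py) and the helpers above (utils.py) are part of this repository
# and are expected to sit next to this script.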

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"
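
# Data layout assumed by the dataset in utils.py (not shown here): each
# *_images folder holds the input photos and each *_masks folder the matching
# binary masks; the "1280/1918 originally" notes are the full-resolution
# Carvana image dimensions this script was written for.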


def train_fn(loader, model, optimizer, loss_fn, scaler):
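    """Run one training epoch over `loader` with mixed-precision updates.

    The forward pass runs under autocast and the backward pass goes through
    the supplied GradScaler, so gradients are loss-scaled before the optimizer
    step. A tqdm bar reports the running batch loss.
    """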
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
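        # masks arrive as (N, H, W); add a channel dim so they match the
        # (N, 1, H, W) logits that UNET(out_channels=1) produces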
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
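        # (autocast runs the forward pass in float16 where it is safe,
        # which saves memory and speeds things up on CUDA devices)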
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
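        # (the scaler multiplies the loss before backward so small float16
        # gradients do not underflow, then unscales before stepping)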
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
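    """Build transforms, data loaders, model, and optimizer, then train.

    Each epoch trains with train_fn, saves a checkpoint, reports validation
    accuracy via check_accuracy, and writes predicted masks to saved_images/.
    """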
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
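            # with zero mean and unit std this Normalize simply rescales
            # pixel values from [0, 255] to [0, 1]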
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
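
    # binary segmentation: the model outputs a single raw logit map per image
    # and BCEWithLogitsLoss applies the sigmoid internally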
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
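
    # get_loaders (defined in utils.py, not shown here) is expected to wrap the
    # image/mask directories into a (train DataLoader, val DataLoader) pair
    # using the transforms above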
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )
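
    # set LOAD_MODEL = True to resume training; load_checkpoint is expected to
    # restore the weights stored under "state_dict" in my_checkpoint.pth.tar
    # (the same layout saved in the epoch loop below)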
    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)

    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()
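
    # per epoch: train, save a checkpoint, report validation accuracy, and
    # write predicted masks to saved_images/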
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )

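
# the __main__ guard matters here because NUM_WORKERS > 0 makes the DataLoaders
# spawn worker processes, which re-import this module on platforms that use the
# "spawn" start method (e.g. Windows)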
if __name__ == "__main__":
    main()