mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torchvision
|
||||
from dataset import CarvanaDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
torch.save(state, filename)
|
||||
|
||||
def load_checkpoint(checkpoint, model):
|
||||
print("=> Loading checkpoint")
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
def get_loaders(
|
||||
train_dir,
|
||||
train_maskdir,
|
||||
val_dir,
|
||||
val_maskdir,
|
||||
batch_size,
|
||||
train_transform,
|
||||
val_transform,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
):
|
||||
train_ds = CarvanaDataset(
|
||||
image_dir=train_dir,
|
||||
mask_dir=train_maskdir,
|
||||
transform=train_transform,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
val_ds = CarvanaDataset(
|
||||
image_dir=val_dir,
|
||||
mask_dir=val_maskdir,
|
||||
transform=val_transform,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_ds,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
def check_accuracy(loader, model, device="cuda"):
|
||||
num_correct = 0
|
||||
num_pixels = 0
|
||||
dice_score = 0
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for x, y in loader:
|
||||
x = x.to(device)
|
||||
y = y.to(device).unsqueeze(1)
|
||||
preds = torch.sigmoid(model(x))
|
||||
preds = (preds > 0.5).float()
|
||||
num_correct += (preds == y).sum()
|
||||
num_pixels += torch.numel(preds)
|
||||
dice_score += (2 * (preds * y).sum()) / (
|
||||
(preds + y).sum() + 1e-8
|
||||
)
|
||||
|
||||
print(
|
||||
f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
|
||||
)
|
||||
print(f"Dice score: {dice_score/len(loader)}")
|
||||
model.train()
|
||||
|
||||
def save_predictions_as_imgs(
|
||||
loader, model, folder="saved_images/", device="cuda"
|
||||
):
|
||||
model.eval()
|
||||
for idx, (x, y) in enumerate(loader):
|
||||
x = x.to(device=device)
|
||||
with torch.no_grad():
|
||||
preds = torch.sigmoid(model(x))
|
||||
preds = (preds > 0.5).float()
|
||||
torchvision.utils.save_image(
|
||||
preds, f"{folder}/pred_{idx}.png"
|
||||
)
|
||||
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
|
||||
|
||||
model.train()
|
||||
Reference in New Issue
Block a user