mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
Initial commit
This commit is contained in:
148
ML/Pytorch/object_detection/YOLO/train.py
Executable file
148
ML/Pytorch/object_detection/YOLO/train.py
Executable file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Main file for training Yolo model on Pascal VOC dataset
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms.functional as FT
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
from model import Yolov1
|
||||
from dataset import VOCDataset
|
||||
from utils import (
|
||||
non_max_suppression,
|
||||
mean_average_precision,
|
||||
intersection_over_union,
|
||||
cellboxes_to_boxes,
|
||||
get_bboxes,
|
||||
plot_image,
|
||||
save_checkpoint,
|
||||
load_checkpoint,
|
||||
)
|
||||
from loss import YoloLoss
|
||||
|
||||
seed = 123
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Hyperparameters etc.
|
||||
LEARNING_RATE = 2e-5
|
||||
DEVICE = "cuda" if torch.cuda.is_available else "cpu"
|
||||
BATCH_SIZE = 16 # 64 in original paper but I don't have that much vram, grad accum?
|
||||
WEIGHT_DECAY = 0
|
||||
EPOCHS = 1000
|
||||
NUM_WORKERS = 2
|
||||
PIN_MEMORY = True
|
||||
LOAD_MODEL = False
|
||||
LOAD_MODEL_FILE = "overfit.pth.tar"
|
||||
IMG_DIR = "data/images"
|
||||
LABEL_DIR = "data/labels"
|
||||
|
||||
|
||||
class Compose(object):
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img, bboxes):
|
||||
for t in self.transforms:
|
||||
img, bboxes = t(img), bboxes
|
||||
|
||||
return img, bboxes
|
||||
|
||||
|
||||
transform = Compose([transforms.Resize((448, 448)), transforms.ToTensor(),])
|
||||
|
||||
|
||||
def train_fn(train_loader, model, optimizer, loss_fn):
|
||||
loop = tqdm(train_loader, leave=True)
|
||||
mean_loss = []
|
||||
|
||||
for batch_idx, (x, y) in enumerate(loop):
|
||||
x, y = x.to(DEVICE), y.to(DEVICE)
|
||||
out = model(x)
|
||||
loss = loss_fn(out, y)
|
||||
mean_loss.append(loss.item())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# update progress bar
|
||||
loop.set_postfix(loss=loss.item())
|
||||
|
||||
print(f"Mean loss was {sum(mean_loss)/len(mean_loss)}")
|
||||
|
||||
|
||||
def main():
|
||||
model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
|
||||
)
|
||||
loss_fn = YoloLoss()
|
||||
|
||||
if LOAD_MODEL:
|
||||
load_checkpoint(torch.load(LOAD_MODEL_FILE), model, optimizer)
|
||||
|
||||
train_dataset = VOCDataset(
|
||||
"data/100examples.csv",
|
||||
transform=transform,
|
||||
img_dir=IMG_DIR,
|
||||
label_dir=LABEL_DIR,
|
||||
)
|
||||
|
||||
test_dataset = VOCDataset(
|
||||
"data/test.csv", transform=transform, img_dir=IMG_DIR, label_dir=LABEL_DIR,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset=train_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=PIN_MEMORY,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
test_loader = DataLoader(
|
||||
dataset=test_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=PIN_MEMORY,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
# for x, y in train_loader:
|
||||
# x = x.to(DEVICE)
|
||||
# for idx in range(8):
|
||||
# bboxes = cellboxes_to_boxes(model(x))
|
||||
# bboxes = non_max_suppression(bboxes[idx], iou_threshold=0.5, threshold=0.4, box_format="midpoint")
|
||||
# plot_image(x[idx].permute(1,2,0).to("cpu"), bboxes)
|
||||
|
||||
# import sys
|
||||
# sys.exit()
|
||||
|
||||
pred_boxes, target_boxes = get_bboxes(
|
||||
train_loader, model, iou_threshold=0.5, threshold=0.4
|
||||
)
|
||||
|
||||
mean_avg_prec = mean_average_precision(
|
||||
pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
|
||||
)
|
||||
print(f"Train mAP: {mean_avg_prec}")
|
||||
|
||||
#if mean_avg_prec > 0.9:
|
||||
# checkpoint = {
|
||||
# "state_dict": model.state_dict(),
|
||||
# "optimizer": optimizer.state_dict(),
|
||||
# }
|
||||
# save_checkpoint(checkpoint, filename=LOAD_MODEL_FILE)
|
||||
# import time
|
||||
# time.sleep(10)
|
||||
|
||||
train_fn(train_loader, model, optimizer, loss_fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user