add yolov3

This commit is contained in:
Aladdin Persson
2021-02-22 16:49:22 +01:00
parent bd4f07f5ed
commit 4a6627a3bf
10 changed files with 1271 additions and 0 deletions

View File

@@ -0,0 +1,115 @@
"""
Main file for training Yolo model on Pascal VOC and COCO dataset
"""
import config
import torch
import torch.optim as optim
from model import YOLOv3
from tqdm import tqdm
from utils import (
mean_average_precision,
cells_to_bboxes,
get_evaluation_bboxes,
save_checkpoint,
load_checkpoint,
check_class_accuracy,
get_loaders,
plot_couple_examples
)
from loss import YoloLoss
torch.backends.cudnn.benchmark = True
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
loop = tqdm(train_loader, leave=True)
losses = []
for batch_idx, (x, y) in enumerate(loop):
x = x.to(config.DEVICE)
y0, y1, y2 = (
y[0].to(config.DEVICE),
y[1].to(config.DEVICE),
y[2].to(config.DEVICE),
)
with torch.cuda.amp.autocast():
out = model(x)
loss = (
loss_fn(out[0], y0, scaled_anchors[0])
+ loss_fn(out[1], y1, scaled_anchors[1])
+ loss_fn(out[2], y2, scaled_anchors[2])
)
losses.append(loss.item())
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# update progress bar
mean_loss = sum(losses) / len(losses)
loop.set_postfix(loss=mean_loss)
def main():
model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
optimizer = optim.Adam(
model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
)
loss_fn = YoloLoss()
scaler = torch.cuda.amp.GradScaler()
train_loader, test_loader, train_eval_loader = get_loaders(
train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
)
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
)
scaled_anchors = (
torch.tensor(config.ANCHORS)
* torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(config.DEVICE)
for epoch in range(config.NUM_EPOCHS):
#plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
if config.SAVE_MODEL:
save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")
#print(f"Currently epoch {epoch}")
#print("On Train Eval loader:")
#check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
#print("On Train loader:")
#check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)
if epoch % 10 == 0 and epoch > 0:
print("On Test loader:")
check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
pred_boxes, true_boxes = get_evaluation_bboxes(
test_loader,
model,
iou_threshold=config.NMS_IOU_THRESH,
anchors=config.ANCHORS,
threshold=config.CONF_THRESHOLD,
)
mapval = mean_average_precision(
pred_boxes,
true_boxes,
iou_threshold=config.MAP_IOU_THRESH,
box_format="midpoint",
num_classes=config.NUM_CLASSES,
)
print(f"MAP: {mapval.item()}")
if __name__ == "__main__":
main()