Files
Machine-Learning-Collection/ML/Pytorch/object_detection/YOLOv3/train.py
Aladdin Persson 4a6627a3bf add yolov3
2021-02-22 16:49:22 +01:00

116 lines
3.4 KiB
Python

"""
Main file for training the YOLOv3 model on the Pascal VOC and COCO datasets
"""
import config
import torch
import torch.optim as optim
from model import YOLOv3
from tqdm import tqdm
from utils import (
mean_average_precision,
cells_to_bboxes,
get_evaluation_bboxes,
save_checkpoint,
load_checkpoint,
check_class_accuracy,
get_loaders,
plot_couple_examples
)
from loss import YoloLoss
# Let cuDNN benchmark the available convolution algorithms and pick the
# fastest one. This is a top-level statement, so it takes effect as soon as
# the module is imported or run.
torch.backends.cudnn.benchmark = True
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """Run one optimisation pass over ``train_loader``.

    For each batch, the three model outputs are scored against the three
    targets with ``loss_fn`` and the matching entry of ``scaled_anchors``.
    The summed loss is backpropagated through ``scaler`` and the optimizer
    is stepped. The tqdm progress bar shows the mean loss over the batches
    processed so far.

    Args:
        train_loader: Iterable yielding ``(x, y)`` pairs, where ``y`` can be
            indexed at 0, 1 and 2 (one target per model output).
        model: Callable whose result can be indexed at 0, 1 and 2.
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        loss_fn: Callable invoked as ``loss_fn(prediction, target, anchors)``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``
            (``main`` passes a ``torch.cuda.amp.GradScaler``).
        scaled_anchors: Object indexable at 0, 1 and 2; entry ``i`` is passed
            to ``loss_fn`` together with output ``i``.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    loop = tqdm(train_loader, leave=True)

    # Keep a running total rather than a list of every loss: the displayed
    # mean then costs O(1) per batch instead of re-summing all earlier values.
    loss_total = 0.0
    batches_seen = 0

    for x, y in loop:
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        # Only the forward pass and the loss computation run under autocast;
        # the backward pass below is deliberately outside the context.
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        loss_total += loss.item()
        batches_seen += 1

        optimizer.zero_grad()
        # Scale the loss before backward and let the scaler drive the
        # optimizer step, then refresh the scale factor for the next batch.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        loop.set_postfix(loss=loss_total / batches_seen)
def main():
    """Build the model, optimizer and data loaders, then run the training loop.

    A checkpoint is restored from ``config.CHECKPOINT_FILE`` first when
    ``config.LOAD_MODEL`` is set. Training runs for ``config.NUM_EPOCHS``
    epochs; after each epoch a checkpoint is written when
    ``config.SAVE_MODEL`` is set. On every tenth epoch after the first
    (10, 20, ...) class accuracy and mean average precision are computed on
    the test loader and printed.
    """
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    # train_eval_loader is only referenced by the optional debugging calls
    # that are commented out inside the epoch loop below.
    train_loader, test_loader, train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/train.csv",
        test_csv_path=config.DATASET + "/test.csv",
    )

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Multiply the anchors in config.ANCHORS[i] by config.S[i]. The (n, 1, 1)
    # factor broadcasts over the remaining anchor dimensions, so the number of
    # anchors per entry does not have to be hard-coded.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        #plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if config.SAVE_MODEL:
            # NOTE(review): the checkpoint is saved under a fixed name while
            # loading uses config.CHECKPOINT_FILE - confirm they are meant to
            # refer to the same file.
            save_checkpoint(model, optimizer, filename="checkpoint.pth.tar")

        #print(f"Currently epoch {epoch}")
        #print("On Train Eval loader:")
        #check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
        #print("On Train loader:")
        #check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)

        # Evaluate on epochs 10, 20, ...; epoch 0 is skipped explicitly.
        if epoch % 10 == 0 and epoch > 0:
            print("On Test loader:")
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)

            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()