Files

125 lines
4.4 KiB
Python
Raw Permalink Normal View History

2021-01-30 21:49:15 +01:00
"""
Implementation of Yolo Loss Function from the original yolo paper
"""
import torch
import torch.nn as nn
from utils import intersection_over_union
class YoloLoss(nn.Module):
"""
Calculate the loss for yolo (v1) model
"""
def __init__(self, S=7, B=2, C=20):
super(YoloLoss, self).__init__()
self.mse = nn.MSELoss(reduction="sum")
"""
S is split size of image (in paper 7),
B is number of boxes (in paper 2),
C is number of classes (in paper and VOC dataset is 20),
"""
self.S = S
self.B = B
self.C = C
# These are from Yolo paper, signifying how much we should
# pay loss for no object (noobj) and the box coordinates (coord)
self.lambda_noobj = 0.5
self.lambda_coord = 5
def forward(self, predictions, target):
# predictions are shaped (BATCH_SIZE, S*S(C+B*5) when inputted
predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)
# Calculate IoU for the two predicted bounding boxes with target bbox
iou_b1 = intersection_over_union(predictions[..., 21:25], target[..., 21:25])
iou_b2 = intersection_over_union(predictions[..., 26:30], target[..., 21:25])
ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)
# Take the box with highest IoU out of the two prediction
# Note that bestbox will be indices of 0, 1 for which bbox was best
iou_maxes, bestbox = torch.max(ious, dim=0)
exists_box = target[..., 20].unsqueeze(3) # in paper this is Iobj_i
# ======================== #
# FOR BOX COORDINATES #
# ======================== #
# Set boxes with no object in them to 0. We only take out one of the two
# predictions, which is the one with highest Iou calculated previously.
box_predictions = exists_box * (
(
bestbox * predictions[..., 26:30]
+ (1 - bestbox) * predictions[..., 21:25]
)
)
box_targets = exists_box * target[..., 21:25]
# Take sqrt of width, height of boxes to ensure that
box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
torch.abs(box_predictions[..., 2:4] + 1e-6)
)
box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])
box_loss = self.mse(
torch.flatten(box_predictions, end_dim=-2),
torch.flatten(box_targets, end_dim=-2),
)
# ==================== #
# FOR OBJECT LOSS #
# ==================== #
# pred_box is the confidence score for the bbox with highest IoU
pred_box = (
bestbox * predictions[..., 25:26] + (1 - bestbox) * predictions[..., 20:21]
)
object_loss = self.mse(
torch.flatten(exists_box * pred_box),
torch.flatten(exists_box * target[..., 20:21]),
)
# ======================= #
# FOR NO OBJECT LOSS #
# ======================= #
#max_no_obj = torch.max(predictions[..., 20:21], predictions[..., 25:26])
#no_object_loss = self.mse(
# torch.flatten((1 - exists_box) * max_no_obj, start_dim=1),
# torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
#)
no_object_loss = self.mse(
torch.flatten((1 - exists_box) * predictions[..., 20:21], start_dim=1),
torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
)
no_object_loss += self.mse(
torch.flatten((1 - exists_box) * predictions[..., 25:26], start_dim=1),
torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1)
)
# ================== #
# FOR CLASS LOSS #
# ================== #
class_loss = self.mse(
torch.flatten(exists_box * predictions[..., :20], end_dim=-2,),
torch.flatten(exists_box * target[..., :20], end_dim=-2,),
)
loss = (
self.lambda_coord * box_loss # first two rows in paper
+ object_loss # third row in paper
+ self.lambda_noobj * no_object_loss # forth row
+ class_loss # fifth row
)
return loss