mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
125 lines
4.4 KiB
Python
Executable File
125 lines
4.4 KiB
Python
Executable File
"""
Implementation of the YOLO loss function from the original YOLO paper.
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from utils import intersection_over_union
|
|
|
|
|
|
class YoloLoss(nn.Module):
    """
    Calculate the loss for yolo (v1) model.

    The loss is a weighted sum of four sum-of-squared-error terms: box
    coordinates, object confidence, no-object confidence and class scores.

    Channel layout read by ``forward`` for every grid cell (last dimension):

    * ``[0, C)``                        class scores
    * ``C + 5 * b``                     confidence of predicted box ``b``
    * ``[C + 5 * b + 1, C + 5 * b + 5)`` four box values of predicted box
      ``b``; the last two are treated as width and height

    ``target`` uses the same layout, but only its first box slot
    (``C`` and ``[C + 1, C + 5)``) is read.
    """

    def __init__(self, S=7, B=2, C=20):
        """
        S is split size of image (in paper 7),
        B is number of boxes (in paper 2),
        C is number of classes (in paper and VOC dataset is 20),
        """
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")

        self.S = S
        self.B = B
        self.C = C

        # These are from Yolo paper, signifying how much we should
        # pay loss for no object (noobj) and the box coordinates (coord)
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    @staticmethod
    def _signed_sqrt(values, eps=1e-6):
        """Return sign(values) * sqrt(|values| + eps), keeping the sqrt gradient finite."""
        # eps is added AFTER abs so the sqrt argument is always >= eps; with
        # abs(values + eps) the argument could still hit 0 (values == -eps),
        # where the derivative of sqrt is infinite.
        return torch.sign(values) * torch.sqrt(torch.abs(values) + eps)

    @staticmethod
    def _select_responsible(masks, per_box_tensors):
        """Combine per-box tensors, keeping in each cell only the box whose mask is 1."""
        return sum(mask * tensor for mask, tensor in zip(masks, per_box_tensors))

    def forward(self, predictions: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the total YOLO loss for a batch.

        Args:
            predictions: tensor that can be reshaped to
                (-1, S, S, C + B * 5), e.g. (BATCH_SIZE, S*S*(C+B*5)).
            target: 4-D tensor (BATCH_SIZE, S, S, D) with D >= C + 5, holding
                class scores, the object indicator at index C and the target
                box at indices C+1 .. C+4.

        Returns:
            Scalar tensor:
            lambda_coord * box_loss + object_loss
            + lambda_noobj * no_object_loss + class_loss
        """
        num_classes = self.C

        # predictions are shaped (BATCH_SIZE, S*S*(C+B*5)) when inputted
        predictions = predictions.reshape(
            -1, self.S, self.S, num_classes + self.B * 5
        )

        # Ground truth per cell. exists_box is Iobj_i in the paper; slicing
        # with C:C+1 keeps the trailing dimension of size 1.
        exists_box = target[..., num_classes:num_classes + 1]
        target_box = target[..., num_classes + 1:num_classes + 5]

        # Split the B predicted boxes into confidence and coordinate parts
        # using offsets derived from C instead of hard-coded indices.
        pred_confidences = []
        pred_boxes = []
        for box_idx in range(self.B):
            offset = num_classes + 5 * box_idx
            pred_confidences.append(predictions[..., offset:offset + 1])
            pred_boxes.append(predictions[..., offset + 1:offset + 5])

        # Calculate IoU for every predicted bounding box with target bbox.
        # NOTE(review): intersection_over_union comes from utils and is assumed
        # to return one IoU per cell with a trailing dimension of size 1, so
        # that bestbox broadcasts against the box slices - confirm in utils.
        ious = torch.stack(
            [intersection_over_union(box, target_box) for box in pred_boxes],
            dim=0,
        )

        # Take the box with highest IoU out of the predictions.
        # bestbox holds, per cell, the index (0 .. B-1) of that box.
        _, bestbox = torch.max(ious, dim=0)
        responsible = [
            (bestbox == box_idx).to(predictions.dtype)
            for box_idx in range(self.B)
        ]

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #

        # Set boxes with no object in them to 0. We only take out one of the
        # predictions, which is the one with highest IoU calculated previously.
        box_predictions = exists_box * self._select_responsible(
            responsible, pred_boxes
        )
        box_targets = exists_box * target_box

        # Take sqrt of width, height of boxes so that an error of a given size
        # weighs more for small boxes than for large ones. The tensors are
        # rebuilt with torch.cat instead of being modified in place, so no
        # tensor saved for the backward pass is overwritten.
        box_predictions = torch.cat(
            [
                box_predictions[..., 0:2],
                self._signed_sqrt(box_predictions[..., 2:4]),
            ],
            dim=-1,
        )
        box_targets = torch.cat(
            [box_targets[..., 0:2], torch.sqrt(box_targets[..., 2:4])],
            dim=-1,
        )

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #

        # pred_box is the confidence score for the bbox with highest IoU
        pred_box = self._select_responsible(responsible, pred_confidences)

        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * exists_box),
        )

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #

        # Every predicted box in a cell without an object is penalised for
        # its confidence, so the term is accumulated over all B boxes.
        no_object = 1 - exists_box
        no_object_target = torch.flatten(no_object * exists_box, start_dim=1)
        no_object_loss = sum(
            self.mse(
                torch.flatten(no_object * confidence, start_dim=1),
                no_object_target,
            )
            for confidence in pred_confidences
        )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #

        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :num_classes], end_dim=-2),
            torch.flatten(exists_box * target[..., :num_classes], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss  # first two rows in paper
            + object_loss  # third row in paper
            + self.lambda_noobj * no_object_loss  # fourth row
            + class_loss  # fifth row
        )

        return loss
|