# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# (synced 2026-02-20 13:50:41 +00:00) — 50 lines, 2.1 KiB, Python
import torch


def _to_corners(boxes, box_format):
    """Split the last dimension of ``boxes`` into corner coordinates.

    Parameters:
        boxes (tensor): Boxes of shape (..., 4); only the first four entries
            of the last dimension are read.
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y) the
            box centre, or "corners" if boxes are (x1, y1, x2, y2).

    Returns:
        tuple: (x1, y1, x2, y2), each a tensor of shape (..., 1).

    Raises:
        ValueError: If ``box_format`` is neither "midpoint" nor "corners".
    """
    # Slicing idx:idx+1 (instead of idx) keeps the trailing singleton
    # dimension, and ``...`` lets any number of leading dimensions through,
    # e.g. (N, S, S, 4) as used by YOLO.
    if box_format == "midpoint":
        x, y = boxes[..., 0:1], boxes[..., 1:2]
        half_w, half_h = boxes[..., 2:3] / 2, boxes[..., 3:4] / 2
        return x - half_w, y - half_h, x + half_w, y + half_h
    if box_format == "corners":
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]
    # Previously an unknown format fell through and crashed later with an
    # UnboundLocalError; fail early with a clear message instead.
    raise ValueError(
        f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
    )


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint", *, eps=1e-6):
    """
    Calculates intersection over union

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes, shape (..., 4),
            e.g. (BATCH_SIZE, 4) or (N, S, S, 4)
        boxes_labels (tensor): Correct Labels of Boxes, shape (..., 4); must be
            broadcast-compatible with boxes_preds
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
        eps (float): Added to the union to avoid division by zero
            (default 1e-6, the previously hard-coded value)

    Returns:
        tensor: Intersection over union for all examples, shape (..., 1)

    Raises:
        ValueError: If box_format is neither "midpoint" nor "corners".
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _to_corners(boxes_preds, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _to_corners(boxes_labels, box_format)

    # Corners of the overlap rectangle (may be "inverted" if boxes are disjoint).
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # Need clamp(0) in case they do not intersect, then we want intersection to be 0
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    return intersection / (box1_area + box2_area - intersection + eps)