mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
cyclegan, progan
This commit is contained in:
@@ -44,7 +44,7 @@ class YoloLoss(nn.Module):
|
||||
anchors = anchors.reshape(1, 3, 1, 1, 2)
|
||||
box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
|
||||
ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
|
||||
object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))
|
||||
object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
|
||||
|
||||
# ======================== #
|
||||
# FOR BOX COORDINATES #
|
||||
@@ -64,12 +64,12 @@ class YoloLoss(nn.Module):
|
||||
(predictions[..., 5:][obj]), (target[..., 5][obj].long()),
|
||||
)
|
||||
|
||||
# print("__________________________________")
|
||||
# print(self.lambda_box * box_loss)
|
||||
# print(self.lambda_obj * object_loss)
|
||||
# print(self.lambda_noobj * no_object_loss)
|
||||
# print(self.lambda_class * class_loss)
|
||||
# print("\n")
|
||||
#print("__________________________________")
|
||||
#print(self.lambda_box * box_loss)
|
||||
#print(self.lambda_obj * object_loss)
|
||||
#print(self.lambda_noobj * no_object_loss)
|
||||
#print(self.lambda_class * class_loss)
|
||||
#print("\n")
|
||||
|
||||
return (
|
||||
self.lambda_box * box_loss
|
||||
|
||||
Reference in New Issue
Block a user