cyclegan, progan

This commit is contained in:
Aladdin Persson
2021-03-11 15:50:44 +01:00
parent 91b1fd156c
commit 2c53205f12
27 changed files with 276 additions and 238 deletions

View File

@@ -44,7 +44,7 @@ class YoloLoss(nn.Module):
anchors = anchors.reshape(1, 3, 1, 1, 2)
box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))
object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
# ======================== #
# FOR BOX COORDINATES #
@@ -64,12 +64,12 @@ class YoloLoss(nn.Module):
(predictions[..., 5:][obj]), (target[..., 5][obj].long()),
)
# print("__________________________________")
# print(self.lambda_box * box_loss)
# print(self.lambda_obj * object_loss)
# print(self.lambda_noobj * no_object_loss)
# print(self.lambda_class * class_loss)
# print("\n")
#print("__________________________________")
#print(self.lambda_box * box_loss)
#print(self.lambda_obj * object_loss)
#print(self.lambda_noobj * no_object_loss)
#print(self.lambda_class * class_loss)
#print("\n")
return (
self.lambda_box * box_loss