mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
cyclegan, progan
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
Put images in images folder, text files for labels in labels folder.
|
||||
Then under COCO put train.csv, and test.csv
|
||||
@@ -1,2 +0,0 @@
|
||||
Put images in images folder, text files for labels in labels folder.
|
||||
Then under PASCAL_VOC put train.csv, and test.csv
|
||||
@@ -11,11 +11,11 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
NUM_WORKERS = 4
|
||||
BATCH_SIZE = 32
|
||||
IMAGE_SIZE = 416
|
||||
NUM_CLASSES = 80
|
||||
LEARNING_RATE = 3e-5
|
||||
NUM_CLASSES = 20
|
||||
LEARNING_RATE = 1e-5
|
||||
WEIGHT_DECAY = 1e-4
|
||||
NUM_EPOCHS = 100
|
||||
CONF_THRESHOLD = 0.6
|
||||
CONF_THRESHOLD = 0.05
|
||||
MAP_IOU_THRESH = 0.5
|
||||
NMS_IOU_THRESH = 0.45
|
||||
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
|
||||
@@ -47,9 +47,9 @@ train_transforms = A.Compose(
|
||||
A.OneOf(
|
||||
[
|
||||
A.ShiftScaleRotate(
|
||||
rotate_limit=10, p=0.4, border_mode=cv2.BORDER_CONSTANT
|
||||
rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
|
||||
),
|
||||
A.IAAAffine(shear=10, p=0.4, mode="constant"),
|
||||
A.IAAAffine(shear=15, p=0.5, mode="constant"),
|
||||
],
|
||||
p=1.0,
|
||||
),
|
||||
|
||||
@@ -44,7 +44,7 @@ class YoloLoss(nn.Module):
|
||||
anchors = anchors.reshape(1, 3, 1, 1, 2)
|
||||
box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
|
||||
ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
|
||||
object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))
|
||||
object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
|
||||
|
||||
# ======================== #
|
||||
# FOR BOX COORDINATES #
|
||||
@@ -64,12 +64,12 @@ class YoloLoss(nn.Module):
|
||||
(predictions[..., 5:][obj]), (target[..., 5][obj].long()),
|
||||
)
|
||||
|
||||
# print("__________________________________")
|
||||
# print(self.lambda_box * box_loss)
|
||||
# print(self.lambda_obj * object_loss)
|
||||
# print(self.lambda_noobj * no_object_loss)
|
||||
# print(self.lambda_class * class_loss)
|
||||
# print("\n")
|
||||
#print("__________________________________")
|
||||
#print(self.lambda_box * box_loss)
|
||||
#print(self.lambda_obj * object_loss)
|
||||
#print(self.lambda_noobj * no_object_loss)
|
||||
#print(self.lambda_class * class_loss)
|
||||
#print("\n")
|
||||
|
||||
return (
|
||||
self.lambda_box * box_loss
|
||||
|
||||
@@ -19,6 +19,8 @@ from utils import (
|
||||
plot_couple_examples
|
||||
)
|
||||
from loss import YoloLoss
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
@@ -80,19 +82,16 @@ def main():
|
||||
#plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
|
||||
train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
|
||||
|
||||
if config.SAVE_MODEL:
|
||||
save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")
|
||||
#if config.SAVE_MODEL:
|
||||
# save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")
|
||||
|
||||
#print(f"Currently epoch {epoch}")
|
||||
#print("On Train Eval loader:")
|
||||
#check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
|
||||
#print("On Train loader:")
|
||||
#check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)
|
||||
|
||||
if epoch % 10 == 0 and epoch > 0:
|
||||
print("On Test loader:")
|
||||
if epoch > 0 and epoch % 3 == 0:
|
||||
check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
|
||||
|
||||
pred_boxes, true_boxes = get_evaluation_bboxes(
|
||||
test_loader,
|
||||
model,
|
||||
@@ -108,7 +107,7 @@ def main():
|
||||
num_classes=config.NUM_CLASSES,
|
||||
)
|
||||
print(f"MAP: {mapval.item()}")
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -380,8 +380,6 @@ def check_class_accuracy(model, loader, threshold):
|
||||
tot_obj, correct_obj = 0, 0
|
||||
|
||||
for idx, (x, y) in enumerate(tqdm(loader)):
|
||||
if idx == 100:
|
||||
break
|
||||
x = x.to(config.DEVICE)
|
||||
with torch.no_grad():
|
||||
out = model(x)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Download and put pretrained weights here!
|
||||
Reference in New Issue
Block a user