mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
add yolov3
This commit is contained in:
115
ML/Pytorch/object_detection/YOLOv3/train.py
Normal file
115
ML/Pytorch/object_detection/YOLOv3/train.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Main file for training Yolo model on Pascal VOC and COCO dataset
|
||||
"""
|
||||
|
||||
import config
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from model import YOLOv3
|
||||
from tqdm import tqdm
|
||||
from utils import (
|
||||
mean_average_precision,
|
||||
cells_to_bboxes,
|
||||
get_evaluation_bboxes,
|
||||
save_checkpoint,
|
||||
load_checkpoint,
|
||||
check_class_accuracy,
|
||||
get_loaders,
|
||||
plot_couple_examples
|
||||
)
|
||||
from loss import YoloLoss
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
|
||||
loop = tqdm(train_loader, leave=True)
|
||||
losses = []
|
||||
for batch_idx, (x, y) in enumerate(loop):
|
||||
x = x.to(config.DEVICE)
|
||||
y0, y1, y2 = (
|
||||
y[0].to(config.DEVICE),
|
||||
y[1].to(config.DEVICE),
|
||||
y[2].to(config.DEVICE),
|
||||
)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
out = model(x)
|
||||
loss = (
|
||||
loss_fn(out[0], y0, scaled_anchors[0])
|
||||
+ loss_fn(out[1], y1, scaled_anchors[1])
|
||||
+ loss_fn(out[2], y2, scaled_anchors[2])
|
||||
)
|
||||
|
||||
losses.append(loss.item())
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
# update progress bar
|
||||
mean_loss = sum(losses) / len(losses)
|
||||
loop.set_postfix(loss=mean_loss)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
|
||||
)
|
||||
loss_fn = YoloLoss()
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
train_loader, test_loader, train_eval_loader = get_loaders(
|
||||
train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
|
||||
)
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
|
||||
)
|
||||
|
||||
scaled_anchors = (
|
||||
torch.tensor(config.ANCHORS)
|
||||
* torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
|
||||
).to(config.DEVICE)
|
||||
|
||||
for epoch in range(config.NUM_EPOCHS):
|
||||
#plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
|
||||
train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
|
||||
|
||||
if config.SAVE_MODEL:
|
||||
save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")
|
||||
|
||||
#print(f"Currently epoch {epoch}")
|
||||
#print("On Train Eval loader:")
|
||||
#check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
|
||||
#print("On Train loader:")
|
||||
#check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)
|
||||
|
||||
if epoch % 10 == 0 and epoch > 0:
|
||||
print("On Test loader:")
|
||||
check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
|
||||
|
||||
pred_boxes, true_boxes = get_evaluation_bboxes(
|
||||
test_loader,
|
||||
model,
|
||||
iou_threshold=config.NMS_IOU_THRESH,
|
||||
anchors=config.ANCHORS,
|
||||
threshold=config.CONF_THRESHOLD,
|
||||
)
|
||||
mapval = mean_average_precision(
|
||||
pred_boxes,
|
||||
true_boxes,
|
||||
iou_threshold=config.MAP_IOU_THRESH,
|
||||
box_format="midpoint",
|
||||
num_classes=config.NUM_CLASSES,
|
||||
)
|
||||
print(f"MAP: {mapval.item()}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user