# Machine-Learning-Collection/ML/Kaggles/Dog vs Cat Competition/train.py
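# Trains a pretrained EfficientNet-B7 on the Kaggle Dogs vs. Cats images with
# mixed-precision training, then exports pooled feature vectors for the train and
# test sets so a separate classifier can be trained on top of them.
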
# Imports
import os
import torch
import torch.nn.functional as F
import numpy as np
import config
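# NOTE: `config` is a local module (not shown here); judging from its usage below it
# is assumed to define DEVICE, BATCH_SIZE, NUM_WORKERS, LEARNING_RATE, WEIGHT_DECAY,
# NUM_EPOCHS, LOAD_MODEL, SAVE_MODEL, CHECKPOINT_FILE, and basic_transform.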
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import CatDog
from efficientnet_pytorch import EfficientNet
from utils import check_accuracy, load_checkpoint, save_checkpoint
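
# Runs the trained backbone over a loader and stores the pooled features and labels
# as .npy files (the data_features/ directory is assumed to exist), so a lightweight
# classifier can later be fit on them.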
def save_feature_vectors(model, loader, output_size=(1, 1), file="trainb7"):
    model.eval()
    images, labels = [], []

    for idx, (x, y) in enumerate(tqdm(loader)):
        x = x.to(config.DEVICE)
        with torch.no_grad():
            features = model.extract_features(x)
            features = F.adaptive_avg_pool2d(features, output_size=output_size)
        images.append(features.reshape(x.shape[0], -1).detach().cpu().numpy())
        labels.append(y.numpy())

    np.save(f"data_features/X_{file}.npy", np.concatenate(images, axis=0))
    np.save(f"data_features/y_{file}.npy", np.concatenate(labels, axis=0))
    model.train()
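
# One pass over the training data with automatic mixed precision: the forward pass
# runs under autocast and the GradScaler handles loss scaling for the backward pass.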
def train_one_epoch(loader, model, loss_fn, optimizer, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(config.DEVICE)
        targets = targets.to(config.DEVICE).unsqueeze(1).float()

        with torch.cuda.amp.autocast():
            scores = model(data)
            loss = loss_fn(scores, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loop.set_postfix(loss=loss.item())
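
# Builds EfficientNet-B7 with a single-logit head, trains it with BCEWithLogitsLoss
# under mixed precision, optionally resumes from / saves a checkpoint, and finally
# exports feature vectors for the train and test splits.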
def main():
    model = EfficientNet.from_pretrained("efficientnet-b7")
    model._fc = nn.Linear(2560, 1)  # replace the classifier head with a single logit (cat vs. dog)
    train_dataset = CatDog(root="data/train/", transform=config.basic_transform)
    test_dataset = CatDog(root="data/test/", transform=config.basic_transform)
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
    )
    model = model.to(config.DEVICE)
    scaler = torch.cuda.amp.GradScaler()
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )

    if config.LOAD_MODEL and config.CHECKPOINT_FILE in os.listdir():
        load_checkpoint(torch.load(config.CHECKPOINT_FILE), model)

    for epoch in range(config.NUM_EPOCHS):
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler)
        check_accuracy(train_loader, model, loss_fn)

        if config.SAVE_MODEL:
            checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
            save_checkpoint(checkpoint, filename=config.CHECKPOINT_FILE)

    # export features once training is done
    save_feature_vectors(model, train_loader, output_size=(1, 1), file="train_b7")
    save_feature_vectors(model, test_loader, output_size=(1, 1), file="test_b7")
if __name__ == "__main__":
    main()
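
# A downstream script (not part of this file) could load the exported features, e.g.:
#   X_train = np.load("data_features/X_train_b7.npy")
#   y_train = np.load("data_features/y_train_b7.npy")
# and fit a separate classifier on them.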