mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
54 lines
2.0 KiB
Python
54 lines
2.0 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.optim as optim
|
||
|
|
from torch.utils.data import DataLoader
|
||
|
|
from tqdm import tqdm
|
||
|
|
from model import Net
|
||
|
|
from utils import check_accuracy, load_checkpoint, save_checkpoint, make_prediction
|
||
|
|
import config
|
||
|
|
from dataset import MyImageFolder
|
||
|
|
|
||
|
|
|
||
|
|
def train_fn(loader, model, optimizer, loss_fn, scaler, device):
    """Run one training epoch over ``loader`` with automatic mixed precision.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches (wrapped in a
            tqdm progress bar).
        model: Network being trained; it is put into training mode before
            the first batch.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Criterion, called as ``loss_fn(scores, targets.float())``.
        scaler: Gradient scaler used to scale the loss, step the optimizer
            and update its scale factor each batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        None. ``model``'s parameters and ``scaler``'s state are updated in
        place.
    """
    # Evaluation helpers called before/between epochs may switch the model to
    # eval mode (dropout off, batch-norm statistics frozen); make sure this
    # epoch really trains in training mode.
    model.train()

    for data, targets in tqdm(loader):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward: eligible ops run in reduced precision under autocast
        with torch.cuda.amp.autocast():
            scores = model(data)
            # NOTE(review): targets are cast to float, which suits BCE-style or
            # one-hot/probability targets. main() pairs this with
            # nn.CrossEntropyLoss, which expects integer class indices unless
            # targets have the same shape as scores -- confirm what
            # MyImageFolder returns.
            loss = loss_fn(scores, targets.float())

        # backward: scale the loss so small fp16 gradients do not underflow,
        # then let the scaler unscale/step and adapt its scale factor.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Build datasets, loaders and model, then train and validate.

    All hyper-parameters come from ``config``. Images are read from the
    ``train/`` and ``val/`` directories. When ``config.LOAD_MODEL`` is set,
    model and optimizer state are restored from ``my_checkpoint.pth.tar``.
    Before training, predictions are made on ``test/`` and validation
    accuracy is reported. After every epoch, validation accuracy is reported
    again and a checkpoint dict (``state_dict`` / ``optimizer``) is passed to
    ``save_checkpoint``.
    """
    train_ds = MyImageFolder(root_dir="train/", transform=config.train_transforms)
    val_ds = MyImageFolder(root_dir="val/", transform=config.val_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=True,
    )
    # Shuffling the validation set does not change an (order-independent)
    # accuracy metric; a fixed order keeps evaluation runs reproducible.
    val_loader = DataLoader(
        val_ds,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=False,
    )

    loss_fn = nn.CrossEntropyLoss()
    model = Net(net_version="b0", num_classes=10).to(config.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    # Loss scaler for the mixed-precision training loop in train_fn.
    scaler = torch.cuda.amp.GradScaler()

    if config.LOAD_MODEL:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=config.DEVICE),
            model,
            optimizer,
        )

    # NOTE(review): when LOAD_MODEL is False these run on freshly initialised
    # weights -- confirm this pre-training prediction/evaluation is intended.
    make_prediction(model, config.val_transforms, 'test/', config.DEVICE)
    check_accuracy(val_loader, model, config.DEVICE)

    for _ in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE)
        check_accuracy(val_loader, model, config.DEVICE)
        checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
        save_checkpoint(checkpoint)
|
||
|
|
|
||
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|