Files
Machine-Learning-Collection/ML/Pytorch/GANs/CycleGAN/train.py

198 lines
6.0 KiB
Python
Raw Normal View History

2022-12-21 14:03:08 +01:00
"""
Training for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
2021-03-06 21:09:08 +01:00
import torch
from dataset import HorseZebraDataset
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
2022-12-21 14:03:08 +01:00
def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
    """Run one pass over ``loader``, updating both discriminators and both generators.

    For every batch the discriminators are updated first (on real images and
    on detached generator outputs), then the generators are updated with the
    adversarial, cycle-consistency and identity terms.

    Args:
        disc_H: Discriminator applied to real horses and to ``gen_H`` outputs.
        disc_Z: Discriminator applied to real zebras and to ``gen_Z`` outputs.
        gen_Z: Generator called on horse images to produce fake zebras.
        gen_H: Generator called on zebra images to produce fake horses.
        loader: Iterable yielding ``(zebra, horse)`` batches.
        opt_disc: Optimizer stepped with the combined discriminator loss.
        opt_gen: Optimizer stepped with the combined generator loss.
        l1: Criterion used for the cycle-consistency and identity terms.
        mse: Criterion used for the adversarial terms.
        d_scaler: Gradient scaler wrapping the discriminator backward/step.
        g_scaler: Gradient scaler wrapping the generator backward/step.

    Side effects:
        Every 200 batches (including batch 0) the current fake horse and fake
        zebra are written to ``saved_images/``. The progress bar shows the
        running mean of ``disc_H`` outputs on real and fake horses.
    """
    import os

    # The first sample is written at idx == 0, so make sure the output
    # directory exists before entering the loop instead of crashing there.
    os.makedirs("saved_images", exist_ok=True)

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(config.DEVICE)
        horse = horse.to(config.DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            # detach(): the discriminator update must not backpropagate
            # into the generator that produced the fake.
            D_H_fake = disc_H(fake_horse.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it together: average of the two discriminator losses
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators (fakes are NOT detached
            # here, so gradients reach the generators)
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            # identity loss (remove these for efficiency if you set lambda_identity=0)
            identity_zebra = gen_Z(zebra)
            identity_horse = gen_H(horse)
            identity_zebra_loss = l1(zebra, identity_zebra)
            identity_horse_loss = l1(horse, identity_horse)

            # add all together
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * config.LAMBDA_CYCLE
                + cycle_horse_loss * config.LAMBDA_CYCLE
                + identity_horse_loss * config.LAMBDA_IDENTITY
                + identity_zebra_loss * config.LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # "* 0.5 + 0.5" rescales the generator output before saving.
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
2021-03-06 21:09:08 +01:00
def main():
    """Build the models, optimizers and data loader, then train CycleGAN.

    Two discriminators share one Adam optimizer and two generators share
    another. When ``config.LOAD_MODEL`` is set, all four networks are restored
    through ``load_checkpoint`` before training. Training runs for
    ``config.NUM_EPOCHS`` epochs; after each epoch all four networks are
    checkpointed when ``config.SAVE_MODEL`` is set.
    """
    disc_H = Discriminator(in_channels=3).to(config.DEVICE)
    disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)

    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if config.LOAD_MODEL:
        # Same order as before: generators first, then discriminators.
        # Each pair of networks is loaded together with its shared optimizer.
        for checkpoint_file, model, optimizer in (
            (config.CHECKPOINT_GEN_H, gen_H, opt_gen),
            (config.CHECKPOINT_GEN_Z, gen_Z, opt_gen),
            (config.CHECKPOINT_CRITIC_H, disc_H, opt_disc),
            (config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
        ):
            load_checkpoint(
                checkpoint_file,
                model,
                optimizer,
                config.LEARNING_RATE,
            )

    dataset = HorseZebraDataset(
        root_horse=config.TRAIN_DIR + "/horses",
        root_zebra=config.TRAIN_DIR + "/zebras",
        transform=config.transforms,
    )
    # NOTE: the previous version also built a validation dataset/loader from
    # hard-coded "cyclegan_test/..." paths that was never used; it was removed
    # so training no longer depends on those directories.
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )

    # Separate gradient scalers for the generator and discriminator updates.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )

        if config.SAVE_MODEL:
            for model, optimizer, checkpoint_file in (
                (gen_H, opt_gen, config.CHECKPOINT_GEN_H),
                (gen_Z, opt_gen, config.CHECKPOINT_GEN_Z),
                (disc_H, opt_disc, config.CHECKPOINT_CRITIC_H),
                (disc_Z, opt_disc, config.CHECKPOINT_CRITIC_Z),
            ):
                save_checkpoint(model, optimizer, filename=checkpoint_file)
2022-12-21 14:03:08 +01:00
2021-03-06 21:09:08 +01:00
# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse train_fn) does not kick off a training run.
if __name__ == "__main__":
    main()