checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,3 +1,11 @@
"""
Training for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
from dataset import HorseZebraDataset
import sys
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
def train_fn(
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
H_reals = 0
H_fakes = 0
loop = tqdm(loader, leave=True)
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
# put it togethor
D_loss = (D_H_loss + D_Z_loss)/2
D_loss = (D_H_loss + D_Z_loss) / 2
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
g_scaler.update()
if idx % 200 == 0:
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
def main():
@@ -111,23 +121,39 @@ def main():
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_H,
gen_H,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_Z,
gen_Z,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_H,
disc_H,
opt_disc,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_Z,
disc_Z,
opt_disc,
config.LEARNING_RATE,
)
dataset = HorseZebraDataset(
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
root_horse=config.TRAIN_DIR + "/horses",
root_zebra=config.TRAIN_DIR + "/zebras",
transform=config.transforms,
)
val_dataset = HorseZebraDataset(
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
root_horse="cyclegan_test/horse1",
root_zebra="cyclegan_test/zebra1",
transform=config.transforms,
)
val_loader = DataLoader(
val_dataset,
@@ -140,13 +166,25 @@ def main():
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True
pin_memory=True,
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
train_fn(
disc_H,
disc_Z,
gen_Z,
gen_H,
loader,
opt_disc,
opt_gen,
L1,
mse,
d_scaler,
g_scaler,
)
if config.SAVE_MODEL:
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
@@ -154,5 +192,6 @@ def main():
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
if __name__ == "__main__":
main()
main()