mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
checked GAN code
This commit is contained in:
@@ -1,3 +1,11 @@
|
||||
"""
|
||||
Training for CycleGAN
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-05: Initial coding
|
||||
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
from dataset import HorseZebraDataset
|
||||
import sys
|
||||
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
|
||||
from discriminator_model import Discriminator
|
||||
from generator_model import Generator
|
||||
|
||||
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
|
||||
|
||||
def train_fn(
|
||||
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
|
||||
):
|
||||
H_reals = 0
|
||||
H_fakes = 0
|
||||
loop = tqdm(loader, leave=True)
|
||||
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
||||
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
|
||||
|
||||
# put it togethor
|
||||
D_loss = (D_H_loss + D_Z_loss)/2
|
||||
D_loss = (D_H_loss + D_Z_loss) / 2
|
||||
|
||||
opt_disc.zero_grad()
|
||||
d_scaler.scale(D_loss).backward()
|
||||
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
||||
g_scaler.update()
|
||||
|
||||
if idx % 200 == 0:
|
||||
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
|
||||
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
|
||||
|
||||
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
|
||||
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
|
||||
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
|
||||
|
||||
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
|
||||
|
||||
|
||||
def main():
|
||||
@@ -111,23 +121,39 @@ def main():
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_GEN_H,
|
||||
gen_H,
|
||||
opt_gen,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_GEN_Z,
|
||||
gen_Z,
|
||||
opt_gen,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_CRITIC_H,
|
||||
disc_H,
|
||||
opt_disc,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_CRITIC_Z,
|
||||
disc_Z,
|
||||
opt_disc,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
|
||||
dataset = HorseZebraDataset(
|
||||
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
|
||||
root_horse=config.TRAIN_DIR + "/horses",
|
||||
root_zebra=config.TRAIN_DIR + "/zebras",
|
||||
transform=config.transforms,
|
||||
)
|
||||
val_dataset = HorseZebraDataset(
|
||||
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
|
||||
root_horse="cyclegan_test/horse1",
|
||||
root_zebra="cyclegan_test/zebra1",
|
||||
transform=config.transforms,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
@@ -140,13 +166,25 @@ def main():
|
||||
batch_size=config.BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=config.NUM_WORKERS,
|
||||
pin_memory=True
|
||||
pin_memory=True,
|
||||
)
|
||||
g_scaler = torch.cuda.amp.GradScaler()
|
||||
d_scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
for epoch in range(config.NUM_EPOCHS):
|
||||
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
|
||||
train_fn(
|
||||
disc_H,
|
||||
disc_Z,
|
||||
gen_Z,
|
||||
gen_H,
|
||||
loader,
|
||||
opt_disc,
|
||||
opt_gen,
|
||||
L1,
|
||||
mse,
|
||||
d_scaler,
|
||||
g_scaler,
|
||||
)
|
||||
|
||||
if config.SAVE_MODEL:
|
||||
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
|
||||
@@ -154,5 +192,6 @@ def main():
|
||||
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
|
||||
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user