"""
Training of WGAN-GP

Trains a Generator against a critic (Discriminator) using the Wasserstein
loss with gradient penalty, logging real/fake image grids to TensorBoard.

Programmed by Aladdin Persson
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import gradient_penalty, save_checkpoint, load_checkpoint
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 100
FEATURES_CRITIC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient-penalty term in the critic loss

# Named `transform` (not `transforms`) so the torchvision.transforms module
# imported above is not shadowed by this Compose instance.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Per-channel mean/std of 0.5 maps [0, 1] pixels to [-1, 1].
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
# comment mnist above and uncomment below for training on CelebA
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers (beta1 = 0.0, beta2 = 0.9 for both networks)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# for tensorboard plotting: the same 32 latent vectors are reused at every
# logging step so generated grids are comparable over time
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)
writer_real = SummaryWriter("logs/GAN_MNIST/real")
writer_fake = SummaryWriter("logs/GAN_MNIST/fake")
step = 0  # counts TensorBoard logging events, not batches

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # `fake` is not detached, so this backward also traverses the
            # generator's graph. retain_graph=True keeps that graph alive
            # because the generator step below backpropagates through the
            # same `fake` from the last critic iteration.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # Fresh forward through the just-updated critic.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        # Clears the gradients the critic backward pass accumulated into
        # the generator's parameters.
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            # Implicit concatenation: a backslash continuation inside the
            # f-string would embed the next line's indentation in the output.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1