"""
Training of DCGAN network with WGAN loss

Programmed by Aladdin Persson
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version

Trains a generator against a weight-clipped critic (WGAN): the critic is updated
CRITIC_ITERATIONS times per generator update, and real/fake image grids are
logged to TensorBoard under logs/real and logs/fake every 100 batches.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1  # set to 3 when training on an RGB dataset such as CelebA
Z_DIM = 128
NUM_EPOCHS = 5
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01

# Named `transform` (not `transforms`) so the torchvision.transforms module
# imported above is not shadowed by the composed pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        # Resize(int) only fixes the shorter edge; the crop guarantees
        # IMAGE_SIZE x IMAGE_SIZE outputs for non-square images (e.g. CelebA).
        # It is a no-op for square inputs such as MNIST.
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
# comment mnist and uncomment below if you want to train on CelebA dataset
# (remember to set CHANNELS_IMG = 3 as well)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# initialize gen and disc/critic
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers (RMSprop, as used with weight clipping)
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# for tensorboard plotting: the same latent vectors every time so that
# successive "Fake" snapshots show how the generator evolves
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # detach(): the critic update must not backpropagate into the
            # generator. This also leaves the generator graph of `fake` intact
            # for the generator step below, so retain_graph is unnecessary.
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # clip critic weights to [-WEIGHT_CLIP, WEIGHT_CLIP]
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # Reuses `fake` from the last critic iteration (still attached to gen).
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake_samples = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(
                    fake_samples[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

# Flush any buffered TensorBoard events to disk.
writer_real.close()
writer_fake.close()