import torch
import torchvision
import torch.nn as nn


# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step, num_images=8
):
    """Log critic/generator losses and real/fake image grids to TensorBoard.

    Args:
        writer: Object exposing ``add_scalar`` and ``add_image``
            (e.g. ``torch.utils.tensorboard.SummaryWriter``).
        loss_critic: Critic loss for this step, logged as "Loss Critic".
        loss_gen: Generator loss for this step, logged as "Loss Generator".
        real: Batch of real images; the first ``num_images`` are gridded.
            Presumably shaped (N, C, H, W) as ``make_grid`` expects — confirm
            against callers.
        fake: Batch of generated images, handled the same way as ``real``.
        tensorboard_step: Global step attached to every logged entry.
        num_images: Maximum number of images taken from each batch.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # loss_gen used to be accepted but silently dropped; log it as well.
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # Take out (up to) num_images examples from each batch.
        img_grid_real = torchvision.utils.make_grid(
            real[:num_images], normalize=True
        )
        img_grid_fake = torchvision.utils.make_grid(
            fake[:num_images], normalize=True
        )
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    One mixing coefficient is drawn per example. The critic is evaluated on
    the interpolated batch, and the penalty is the batch mean of
    ``(||grad||_2 - 1) ** 2``, where ``grad`` is the gradient of the critic
    scores with respect to the interpolated inputs.

    Args:
        critic: Callable invoked as ``critic(images, alpha, train_step)``.
        real: Real batch; the leading dimension is the batch size.
        fake: Generated batch with the same shape as ``real``.
        alpha: Forwarded to ``critic`` unchanged.
        train_step: Forwarded to ``critic`` unchanged.
        device: Kept for backward compatibility and no longer used. The
            coefficients are created on ``real.device`` so the ``"cpu"``
            default no longer breaks when the batch lives on a GPU.

    Returns:
        A scalar tensor. Because ``create_graph=True`` is used, the penalty
        can be back-propagated.
    """
    batch_size = real.shape[0]
    # One coefficient per example, shaped to broadcast over all remaining
    # dimensions (no need to materialise a full-size copy with .repeat).
    beta = torch.rand(
        (batch_size,) + (1,) * (real.dim() - 1), device=real.device
    )
    interpolated_images = real * beta + fake * (1 - beta)
    # autograd.grad needs the interpolation to require grad. If neither input
    # does (e.g. ``fake`` was detached), the result is a leaf, and marking it
    # here is safe.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): the gradient may be non-contiguous, e.g. channels_last.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to save (typically a dict of state_dicts).
        filename: Destination path.
    """
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, gen, disc, opt_gen=None, opt_disc=None):
    """Restore model and, optionally, optimizer state from a checkpoint mapping.

    Args:
        checkpoint: Mapping with keys ``'gen'`` and ``'critic'``. It also needs
            ``'opt_gen'`` / ``'opt_critic'`` when the corresponding optimizer
            is passed.
        gen: Generator; receives ``checkpoint['gen']``.
        disc: Critic; receives ``checkpoint['critic']``.
        opt_gen: Optional generator optimizer; receives ``checkpoint['opt_gen']``.
        opt_disc: Optional critic optimizer; receives ``checkpoint['opt_critic']``.

    Raises:
        KeyError: If a required key is missing from ``checkpoint``.
    """
    print("=> Loading checkpoint")
    gen.load_state_dict(checkpoint['gen'])
    disc.load_state_dict(checkpoint['critic'])
    # Each optimizer is restored independently. Previously a lone optimizer
    # was silently ignored unless both were supplied.
    if opt_gen is not None:
        opt_gen.load_state_dict(checkpoint['opt_gen'])
    if opt_disc is not None:
        opt_disc.load_state_dict(checkpoint['opt_critic'])