import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision


# Log the losses occasionally and write image grids to TensorBoard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    writer.add_scalar("Loss Gen", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples for the image grids
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    # Sample one interpolation coefficient per example in the batch
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    # Map to whatever device is available so CPU-only machines can load too
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have the learning rate of the old
    # checkpoint, and it will lead to many hours of debugging :\
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
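

# A minimal smoke test for gradient_penalty, sketched under the assumption
# that the critic follows the (images, alpha, train_step) call signature used
# above. _ToyCritic is a hypothetical stand-in, not part of the original repo;
# it exists only to show how the helper is wired up end to end.
if __name__ == "__main__":
    class _ToyCritic(nn.Module):
        def __init__(self):
            super().__init__()
            # A single linear layer over flattened 3x8x8 images is enough to
            # produce the scalar score the penalty differentiates through
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))

        def forward(self, x, alpha, train_step):
            # alpha and train_step are accepted but unused in this toy model
            return self.net(x)

    seed_everything(0)
    critic = _ToyCritic()
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)
    gp = gradient_penalty(critic, real, fake, alpha=1.0, train_step=0)
    # In a real WGAN-GP loop this term would be scaled by a lambda
    # hyperparameter and added to the critic loss
    print(f"gradient penalty: {gp.item():.4f}")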