mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
stylegan, esrgan, srgan code
This commit is contained in:
126
ML/Pytorch/GANs/StyleGAN/utils.py
Normal file
126
ML/Pytorch/GANs/StyleGAN/utils.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import torchvision
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
|
||||
# Print losses occasionally and print to tensorboard
|
||||
def plot_to_tensorboard(
|
||||
writer, loss_critic, loss_gen, real, fake, tensorboard_step
|
||||
):
|
||||
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
|
||||
|
||||
with torch.no_grad():
|
||||
# take out (up to) 32 examples
|
||||
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
|
||||
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
|
||||
writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
|
||||
|
||||
|
||||
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
|
||||
BATCH_SIZE, C, H, W = real.shape
|
||||
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
|
||||
interpolated_images = real * beta + fake.detach() * (1 - beta)
|
||||
interpolated_images.requires_grad_(True)
|
||||
|
||||
# Calculate critic scores
|
||||
mixed_scores = critic(interpolated_images, alpha, train_step)
|
||||
|
||||
# Take the gradient of the scores with respect to the images
|
||||
gradient = torch.autograd.grad(
|
||||
inputs=interpolated_images,
|
||||
outputs=mixed_scores,
|
||||
grad_outputs=torch.ones_like(mixed_scores),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
)[0]
|
||||
gradient = gradient.view(gradient.shape[0], -1)
|
||||
gradient_norm = gradient.norm(2, dim=1)
|
||||
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
|
||||
return gradient_penalty
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
checkpoint = {
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_file, model, optimizer, lr):
|
||||
print("=> Loading checkpoint")
|
||||
checkpoint = torch.load(checkpoint_file, map_location="cuda")
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
|
||||
# If we don't do this then it will just have learning rate of old checkpoint
|
||||
# and it will lead to many hours of debugging \:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
def seed_everything(seed=42):
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
class EMA:
|
||||
# Found this useful (thanks alexis-jacq):
|
||||
# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3
|
||||
def __init__(self, gamma=0.99, save=True, save_frequency=100, save_filename="ema_weights.pth"):
|
||||
"""
|
||||
Initialize the weight to which we will do the
|
||||
exponential moving average and the dictionary
|
||||
where we store the model parameters
|
||||
"""
|
||||
self.gamma = gamma
|
||||
self.registered = {}
|
||||
self.save_filename = save_filename
|
||||
self.save_frequency = save_frequency
|
||||
self.count = 0
|
||||
|
||||
if save_filename in os.listdir("."):
|
||||
self.registered = torch.load(self.save_filename)
|
||||
|
||||
if not save:
|
||||
warnings.warn("Note that the exponential moving average weights will not be saved to a .pth file!")
|
||||
|
||||
def register_weights(self, model):
|
||||
"""
|
||||
Registers the weights of the model which will
|
||||
later be used when we take the moving average
|
||||
"""
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.registered[name] = param.clone().detach()
|
||||
|
||||
def __call__(self, model):
|
||||
self.count += 1
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
new_weight = param.clone().detach() if name not in self.registered else self.gamma * param + (1 - self.gamma) * self.registered[name]
|
||||
self.registered[name] = new_weight
|
||||
|
||||
if self.count % self.save_frequency == 0:
|
||||
self.save_ema_weights()
|
||||
|
||||
def copy_weights_to(self, model):
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
param.data = self.registered[name]
|
||||
|
||||
def save_ema_weights(self):
|
||||
torch.save(self.registered, self.save_filename)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user