updated progan

This commit is contained in:
Aladdin Persson
2021-03-21 12:19:18 +01:00
parent c72d1d6a31
commit 59b1de7bfe
5 changed files with 29 additions and 9 deletions

View File

@@ -4,6 +4,9 @@ import numpy as np
import os
import torchvision
import torch.nn as nn
import config
from torchvision.utils import save_image
from scipy.stats import truncnorm
# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
@@ -12,7 +15,7 @@ def plot_to_tensorboard(
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
with torch.no_grad():
# take out (up to) 32 examples
# take out (up to) 8 examples to plot
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
@@ -72,4 +75,18 @@ def seed_everything(seed=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def generate_examples(gen, steps, truncation=0.7, n=100):
"""
Tried using truncation trick here but not sure it actually helped anything, you can
remove it if you like and just sample from torch.randn
"""
gen.eval()
alpha = 1.0
for i in range(n):
with torch.no_grad():
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
img = gen(noise, alpha, steps)
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
gen.train()