mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
updated progan
This commit is contained in:
@@ -4,6 +4,9 @@ import numpy as np
|
||||
import os
|
||||
import torchvision
|
||||
import torch.nn as nn
|
||||
import config
|
||||
from torchvision.utils import save_image
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
# Print losses occasionally and print to tensorboard
|
||||
def plot_to_tensorboard(
|
||||
@@ -12,7 +15,7 @@ def plot_to_tensorboard(
|
||||
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
|
||||
|
||||
with torch.no_grad():
|
||||
# take out (up to) 32 examples
|
||||
# take out (up to) 8 examples to plot
|
||||
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
|
||||
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
|
||||
@@ -72,4 +75,18 @@ def seed_everything(seed=42):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def generate_examples(gen, steps, truncation=0.7, n=100):
|
||||
"""
|
||||
Tried using truncation trick here but not sure it actually helped anything, you can
|
||||
remove it if you like and just sample from torch.randn
|
||||
"""
|
||||
gen.eval()
|
||||
alpha = 1.0
|
||||
for i in range(n):
|
||||
with torch.no_grad():
|
||||
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
|
||||
img = gen(noise, alpha, steps)
|
||||
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
|
||||
gen.train()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user