diff --git a/ML/Pytorch/GANs/ProGAN/README.md b/ML/Pytorch/GANs/ProGAN/README.md
index e5b0f18..18a6eb5 100644
--- a/ML/Pytorch/GANs/ProGAN/README.md
+++ b/ML/Pytorch/GANs/ProGAN/README.md
@@ -4,16 +4,15 @@ A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to
 ## Results
 The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
 
-|First is 64 random examples (not cherry picked) and second is more cherry picked examples. |
+|First is a set of cherry-picked examples and second is 64 examples sampled from random latent vectors.|
 |:---:|
-|![](results/64_examples.png)|
 |![](results/result1.png)|
+|![](results/64_examples.png)|
 
 
 ### Celeb-HQ dataset
 The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
 
-
 ### Download pretrained weights
 Pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).
 
diff --git a/ML/Pytorch/GANs/ProGAN/config.py b/ML/Pytorch/GANs/ProGAN/config.py
index b9ef9ea..8e8c25e 100644
--- a/ML/Pytorch/GANs/ProGAN/config.py
+++ b/ML/Pytorch/GANs/ProGAN/config.py
@@ -8,7 +8,7 @@ CHECKPOINT_GEN = "generator.pth"
 CHECKPOINT_CRITIC = "critic.pth"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SAVE_MODEL = True
-LOAD_MODEL = True
+LOAD_MODEL = False
 LEARNING_RATE = 1e-3
 BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
 CHANNELS_IMG = 3
diff --git a/ML/Pytorch/GANs/ProGAN/model.py b/ML/Pytorch/GANs/ProGAN/model.py
index fb12b80..f22e1d3 100644
--- a/ML/Pytorch/GANs/ProGAN/model.py
+++ b/ML/Pytorch/GANs/ProGAN/model.py
@@ -134,7 +134,7 @@ class Generator(nn.Module):
 
 
 class Discriminator(nn.Module):
-    def __init__(self, z_dim, in_channels, img_channels=3):
+    def __init__(self, in_channels, img_channels=3):
         super(Discriminator, self).__init__()
         self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
         self.leaky = nn.LeakyReLU(0.2)
diff --git a/ML/Pytorch/GANs/ProGAN/train.py b/ML/Pytorch/GANs/ProGAN/train.py
index f283809..41b6b68 100644
--- a/ML/Pytorch/GANs/ProGAN/train.py
+++ b/ML/Pytorch/GANs/ProGAN/train.py
@@ -11,6 +11,7 @@ from utils import (
     plot_to_tensorboard,
     save_checkpoint,
     load_checkpoint,
+    generate_examples,
 )
 from model import Discriminator, Generator
 from math import log2
@@ -130,9 +131,8 @@ def train_fn(
 def main():
     # initialize gen and disc, note: discriminator should be called critic,
     # according to WGAN paper (since it no longer outputs between [0, 1])
-    # but really who cares..
     gen = Generator(
-        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
     ).to(config.DEVICE)
     critic = Discriminator(
         config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
@@ -147,7 +147,7 @@ def main():
     scaler_gen = torch.cuda.amp.GradScaler()
 
     # for tensorboard plotting
-    writer = SummaryWriter(f"logs/gan1")
+    writer = SummaryWriter(f"logs/gan")
 
     if config.LOAD_MODEL:
         load_checkpoint(
@@ -163,6 +163,10 @@
     tensorboard_step = 0
     # start at step that corresponds to img size that we set in config
     step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
+
+    generate_examples(gen, step)
+    import sys
+    sys.exit()
     for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
         alpha = 1e-5  # start with very low alpha
         loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
diff --git a/ML/Pytorch/GANs/ProGAN/utils.py b/ML/Pytorch/GANs/ProGAN/utils.py
index a81f0a1..92d32d4 100644
--- a/ML/Pytorch/GANs/ProGAN/utils.py
+++ b/ML/Pytorch/GANs/ProGAN/utils.py
@@ -4,6 +4,9 @@ import numpy as np
 import os
 import torchvision
 import torch.nn as nn
+import config
+from torchvision.utils import save_image
+from scipy.stats import truncnorm
 
 # Print losses occasionally and print to tensorboard
 def plot_to_tensorboard(
@@ -12,7 +15,7 @@
     writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
 
     with torch.no_grad():
-        # take out (up to) 32 examples
+        # take out (up to) 8 examples to plot
         img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
         img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
         writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
@@ -72,4 +75,18 @@
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
 
 
+def generate_examples(gen, steps, truncation=0.7, n=100):
+    """
+    Tried using the truncation trick here but not sure it actually helped anything; you can
+    remove it if you like and just sample from torch.randn
+    """
+    gen.eval()
+    alpha = 1.0
+    for i in range(n):
+        with torch.no_grad():
+            noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
+            img = gen(noise, alpha, steps)
+            save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
+    gen.train()
+
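
The docstring of `generate_examples` notes that the truncation trick may not have helped much and that you can instead just sample from `torch.randn`. Below is a minimal sketch of that plain-Gaussian variant, assuming the same `config.Z_DIM` / `config.DEVICE` values and the `gen(noise, alpha, steps)` call used in the patch; the helper name `generate_examples_randn` and the `os.makedirs` guard are illustrative additions, not part of the repo.

```python
import os

import torch
import config  # same config module the patch imports in utils.py
from torchvision.utils import save_image


def generate_examples_randn(gen, steps, n=100):
    """Same loop as generate_examples, but with a plain standard-normal latent."""
    gen.eval()
    alpha = 1.0
    os.makedirs("saved_examples", exist_ok=True)  # same output folder as the patch
    for i in range(n):
        with torch.no_grad():
            # sample the latent directly from N(0, 1) instead of a truncated normal
            noise = torch.randn(1, config.Z_DIM, 1, 1, device=config.DEVICE)
            img = gen(noise, alpha, steps)
            # generator outputs are roughly in [-1, 1]; shift to [0, 1] before saving
            save_image(img * 0.5 + 0.5, f"saved_examples/img_{i}.png")
    gen.train()
```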