updated progan

This commit is contained in:
Aladdin Persson
2021-03-21 12:19:18 +01:00
parent c72d1d6a31
commit 59b1de7bfe
5 changed files with 29 additions and 9 deletions

View File

@@ -4,16 +4,15 @@ A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to
## Results
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
|First is 64 random examples (not cherry picked) and second is more cherry picked examples. |
|First is some more cherry-picked examples and second is images sampled from random latent vectors|
|:---:|
|![](results/64_examples.png)|
|![](results/result1.png)|
|![](results/64_examples.png)|
### Celeb-HQ dataset
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
### Download pretrained weights
Pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).

View File

@@ -8,7 +8,7 @@ CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = True
LOAD_MODEL = False
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
CHANNELS_IMG = 3

View File

@@ -134,7 +134,7 @@ class Generator(nn.Module):
class Discriminator(nn.Module):
def __init__(self, z_dim, in_channels, img_channels=3):
def __init__(self, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)

View File

@@ -11,6 +11,7 @@ from utils import (
plot_to_tensorboard,
save_checkpoint,
load_checkpoint,
generate_examples,
)
from model import Discriminator, Generator
from math import log2
@@ -130,9 +131,8 @@ def train_fn(
def main():
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
# but really who cares..
gen = Generator(
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
@@ -147,7 +147,7 @@ def main():
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
writer = SummaryWriter(f"logs/gan1")
writer = SummaryWriter(f"logs/gan")
if config.LOAD_MODEL:
load_checkpoint(
@@ -163,6 +163,10 @@ def main():
tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
generate_examples(gen, step)
import sys
sys.exit()
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4

View File

@@ -4,6 +4,9 @@ import numpy as np
import os
import torchvision
import torch.nn as nn
import config
from torchvision.utils import save_image
from scipy.stats import truncnorm
# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
@@ -12,7 +15,7 @@ def plot_to_tensorboard(
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
with torch.no_grad():
# take out (up to) 32 examples
# take out (up to) 8 examples to plot
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
@@ -72,4 +75,18 @@ def seed_everything(seed=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def generate_examples(gen, steps, truncation=0.7, n=100):
    """
    Generate `n` images from the generator and save them as PNGs under
    saved_examples/.

    Latent vectors are drawn from a truncated normal (the "truncation
    trick") rather than torch.randn; not sure it actually helped anything,
    you can remove it and just sample from torch.randn if you like.

    Args:
        gen: the Generator; put into eval() for sampling and restored to
            train() mode before returning.
        steps: current progressive-growing step, forwarded to gen.
        truncation: bound of the truncated-normal latent distribution.
        n: number of images to generate.
    """
    # save_image does not create missing directories — ensure it exists
    # so a fresh checkout doesn't crash with FileNotFoundError.
    os.makedirs("saved_examples", exist_ok=True)
    gen.eval()
    alpha = 1.0  # generate with the new layers fully faded in
    for i in range(n):
        with torch.no_grad():
            noise = torch.tensor(
                truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)),
                device=config.DEVICE,
                dtype=torch.float32,
            )
            img = gen(noise, alpha, steps)
            # generator output is assumed to be in [-1, 1]; map to [0, 1]
            save_image(img * 0.5 + 0.5, f"saved_examples/img_{i}.png")
    gen.train()