mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
updated progan
This commit is contained in:
@@ -4,16 +4,15 @@ A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to
|
||||
## Results
|
||||
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
|
||||
|
||||
|First is 64 random examples (not cherry picked) and second is more cherry picked examples. |
|
||||
|First is some more cherrypicked examples and second is just sampled from random latent vectors|
|
||||
|:---:|
|
||||
||
|
||||
||
|
||||
||
|
||||
|
||||
|
||||
### Celeb-HQ dataset
|
||||
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
|
||||
|
||||
|
||||
### Download pretrained weights
|
||||
Pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ CHECKPOINT_GEN = "generator.pth"
|
||||
CHECKPOINT_CRITIC = "critic.pth"
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
SAVE_MODEL = True
|
||||
LOAD_MODEL = True
|
||||
LOAD_MODEL = False
|
||||
LEARNING_RATE = 1e-3
|
||||
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
|
||||
CHANNELS_IMG = 3
|
||||
|
||||
@@ -134,7 +134,7 @@ class Generator(nn.Module):
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, z_dim, in_channels, img_channels=3):
|
||||
def __init__(self, in_channels, img_channels=3):
|
||||
super(Discriminator, self).__init__()
|
||||
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
|
||||
self.leaky = nn.LeakyReLU(0.2)
|
||||
|
||||
@@ -11,6 +11,7 @@ from utils import (
|
||||
plot_to_tensorboard,
|
||||
save_checkpoint,
|
||||
load_checkpoint,
|
||||
generate_examples,
|
||||
)
|
||||
from model import Discriminator, Generator
|
||||
from math import log2
|
||||
@@ -130,9 +131,8 @@ def train_fn(
|
||||
def main():
|
||||
# initialize gen and disc, note: discriminator should be called critic,
|
||||
# according to WGAN paper (since it no longer outputs between [0, 1])
|
||||
# but really who cares..
|
||||
gen = Generator(
|
||||
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
).to(config.DEVICE)
|
||||
critic = Discriminator(
|
||||
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
@@ -147,7 +147,7 @@ def main():
|
||||
scaler_gen = torch.cuda.amp.GradScaler()
|
||||
|
||||
# for tensorboard plotting
|
||||
writer = SummaryWriter(f"logs/gan1")
|
||||
writer = SummaryWriter(f"logs/gan")
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
@@ -163,6 +163,10 @@ def main():
|
||||
tensorboard_step = 0
|
||||
# start at step that corresponds to img size that we set in config
|
||||
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
|
||||
|
||||
generate_examples(gen, step)
|
||||
import sys
|
||||
sys.exit()
|
||||
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
|
||||
alpha = 1e-5 # start with very low alpha
|
||||
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
|
||||
|
||||
@@ -4,6 +4,9 @@ import numpy as np
|
||||
import os
|
||||
import torchvision
|
||||
import torch.nn as nn
|
||||
import config
|
||||
from torchvision.utils import save_image
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
# Print losses occasionally and print to tensorboard
|
||||
def plot_to_tensorboard(
|
||||
@@ -12,7 +15,7 @@ def plot_to_tensorboard(
|
||||
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
|
||||
|
||||
with torch.no_grad():
|
||||
# take out (up to) 32 examples
|
||||
# take out (up to) 8 examples to plot
|
||||
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
|
||||
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
|
||||
@@ -72,4 +75,18 @@ def seed_everything(seed=42):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def generate_examples(gen, steps, truncation=0.7, n=100):
|
||||
"""
|
||||
Tried using truncation trick here but not sure it actually helped anything, you can
|
||||
remove it if you like and just sample from torch.randn
|
||||
"""
|
||||
gen.eval()
|
||||
alpha = 1.0
|
||||
for i in range(n):
|
||||
with torch.no_grad():
|
||||
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
|
||||
img = gen(noise, alpha, steps)
|
||||
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
|
||||
gen.train()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user