update to progan

This commit is contained in:
Aladdin Persson
2021-03-17 22:48:26 +01:00
parent dc7f4f4ee7
commit 06bd8204b3
5 changed files with 163 additions and 90 deletions

View File

@@ -1,5 +1,5 @@
# ProGAN
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is on par with the paper, I'll include some examples results below.
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is pretty close to the original paper (I'll include some examples results below) but because of time limitation I only trained to 256x256 and on lower model size than they did in the paper. Making the number of channels to 512 instead of 256 as I trained it would probably make the results even better :)
## Results
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.

View File

@@ -0,0 +1,21 @@
import cv2
import torch
from math import log2
START_TRAIN_AT_IMG_SIZE = 4
DATASET = 'celeb_dataset'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = True
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
CHANNELS_IMG = 3
Z_DIM = 256 # should be 512 in original paper
IN_CHANNELS = 256 # should be 512 in original paper
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 4

View File

@@ -1,4 +0,0 @@
it = iter(l)
for el in it:
print(el, next(it))

View File

@@ -1,50 +1,38 @@
""" Training of ProGAN using WGAN-GP loss"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import (
gradient_penalty,
plot_to_tensorboard,
save_checkpoint,
load_checkpoint,
generate_examples,
)
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import time
import config
torch.backends.cudnn.benchmarks = True
def get_loader(image_size):
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(
[0.5 for _ in range(config.CHANNELS_IMG)],
[0.5 for _ in range(config.CHANNELS_IMG)],
),
]
)
batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True,
)
batch_size = config.BATCH_SIZES[int(log2(image_size/4))]
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
return loader, dataset
def train_fn(
critic,
gen,
@@ -59,96 +47,91 @@ def train_fn(
scaler_gen,
scaler_critic,
):
start = time.time()
total_time = 0
loop = tqdm(loader, leave=True)
# critic_losses = []
reals = 0
fakes = 0
losses_critic = []
for batch_idx, (real, _) in enumerate(loop):
real = real.to(config.DEVICE)
cur_batch_size = real.shape[0]
model_start = time.time()
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
# which is equivalent to minimizing the negative of the expression
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
for _ in range(4):
# Train Critic: max E[critic(real)] - E[critic(fake)]
# which is equivalent to minimizing the negative of the expression
for _ in range(config.CRITIC_ITERATIONS):
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
reals += critic_real.mean().item()
fakes += critic_fake.mean().item()
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
+ (0.001 * torch.mean(critic_real ** 2))
)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step).reshape(-1)
critic_fake = critic(fake, alpha, step).reshape(-1)
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
)
opt_critic.zero_grad()
scaler_critic.scale(loss_critic).backward()
scaler_critic.step(opt_critic)
scaler_critic.update()
losses_critic.append(loss_critic.item())
opt_critic.zero_grad()
scaler_critic.scale(loss_critic).backward()
scaler_critic.step(opt_critic)
scaler_critic.update()
#loss_critic.backward(retain_graph=True)
#opt_critic.step()
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
with torch.cuda.amp.autocast():
gen_fake = critic(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
gen_fake = critic(fake, alpha, step).reshape(-1)
loss_gen = -torch.mean(gen_fake)
opt_gen.zero_grad()
scaler_gen.scale(loss_gen).backward()
scaler_gen.step(opt_gen)
scaler_gen.update()
opt_gen.zero_grad()
scaler_gen.scale(loss_gen).backward()
scaler_gen.step(opt_gen)
scaler_gen.update()
#gen.zero_grad()
#loss_gen.backward()
#opt_gen.step()
# Update alpha and ensure less than 1
alpha += cur_batch_size / (
(config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
(config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
)
alpha = min(alpha, 1)
total_time += time.time()-model_start
if batch_idx % 500 == 0:
if batch_idx % 10 == 0:
print(alpha)
with torch.no_grad():
fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
plot_to_tensorboard(
writer,
loss_critic.item(),
loss_gen.item(),
real.detach(),
fixed_fakes.detach(),
tensorboard_step,
writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
)
tensorboard_step += 1
loop.set_postfix(
reals=reals / (batch_idx + 1),
fakes=fakes / (batch_idx + 1),
gp=gp.item(),
loss_critic=loss_critic.item(),
)
mean_loss = sum(losses_critic) / len(losses_critic)
loop.set_postfix(loss=mean_loss)
print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
return tensorboard_step, alpha
def main():
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
# but really who cares..
gen = Generator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
# initialize optimizers and scalers for FP16 training
# initializate optimizer
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)
opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
writer = SummaryWriter(f"logs/gan1")
writer = SummaryWriter(f"logs/gan")
if config.LOAD_MODEL:
load_checkpoint(
@@ -162,13 +145,12 @@ def main():
critic.train()
tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
print(f"Current image size: {4 * 2 ** step}")
step = int(log2(config.START_TRAIN_AT_IMG_SIZE/4))
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 0.01
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3
print(f"Current image size: {4*2**step}")
for epoch in range(num_epochs):
print(f"Epoch [{epoch+1}/{num_epochs}]")
tensorboard_step, alpha = train_fn(
@@ -190,8 +172,7 @@ def main():
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
step += 1 # progress to the next img size
step += 1
if __name__ == "__main__":
main()
main()

View File

@@ -0,0 +1,75 @@
import torch
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
with torch.no_grad():
# take out (up to) 32 examples
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images, alpha, train_step)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location="cuda")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
# If we don't do this then it will just have learning rate of old checkpoint
# and it will lead to many hours of debugging \:
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def seed_everything(seed=42):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False