cyclegan, progan

This commit is contained in:
Aladdin Persson
2021-03-11 15:50:44 +01:00
parent 91b1fd156c
commit 2c53205f12
27 changed files with 276 additions and 238 deletions

View File

@@ -1,54 +1,50 @@
""" Training of ProGAN using WGAN-GP loss"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
from utils import (
gradient_penalty,
plot_to_tensorboard,
save_checkpoint,
load_checkpoint,
generate_examples,
)
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import time
import config
torch.backends.cudnn.benchmarks = True
torch.manual_seed(0)
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZES = [128, 128, 64, 16, 8, 4, 2, 2, 1]
IMAGE_SIZE = 128
CHANNELS_IMG = 3
Z_DIM = 128
IN_CHANNELS = 128
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1
PROGRESSIVE_EPOCHS = [2 ** i for i in range(int(log2(IMAGE_SIZE / 4) + 1))]
PROGRESSIVE_EPOCHS = [8 for i in range(int(log2(IMAGE_SIZE / 4) + 1))]
fixed_noise = torch.randn(8, Z_DIM, 1, 1).to(device)
NUM_WORKERS = 4
def get_loader(image_size):
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)],
[0.5 for _ in range(CHANNELS_IMG)],
[0.5 for _ in range(config.CHANNELS_IMG)],
[0.5 for _ in range(config.CHANNELS_IMG)],
),
]
)
batch_size = BATCH_SIZES[int(log2(image_size/4))]
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True,
)
return loader, dataset
def train_fn(
critic,
gen,
@@ -60,85 +56,119 @@ def train_fn(
opt_gen,
tensorboard_step,
writer,
scaler_gen,
scaler_critic,
):
start = time.time()
total_time = 0
training = tqdm(loader, leave=True)
for batch_idx, (real, _) in enumerate(training):
real = real.to(device)
loop = tqdm(loader, leave=True)
# critic_losses = []
reals = 0
fakes = 0
for batch_idx, (real, _) in enumerate(loop):
real = real.to(config.DEVICE)
cur_batch_size = real.shape[0]
model_start = time.time()
# Train Critic: max E[critic(real)] - E[critic(fake)]
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
# which is equivalent to minimizing the negative of the expression
for _ in range(CRITIC_ITERATIONS):
critic.zero_grad()
noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step).reshape(-1)
critic_fake = critic(fake, alpha, step).reshape(-1)
gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
reals += critic_real.mean().item()
fakes += critic_fake.mean().item()
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ LAMBDA_GP * gp
+ config.LAMBDA_GP * gp
+ (0.001 * torch.mean(critic_real ** 2))
)
loss_critic.backward(retain_graph=True)
opt_critic.step()
opt_critic.zero_grad()
scaler_critic.scale(loss_critic).backward()
scaler_critic.step(opt_critic)
scaler_critic.update()
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
gen.zero_grad()
fake = gen(noise, alpha, step)
gen_fake = critic(fake, alpha, step).reshape(-1)
loss_gen = -torch.mean(gen_fake)
loss_gen.backward()
opt_gen.step()
with torch.cuda.amp.autocast():
gen_fake = critic(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
opt_gen.zero_grad()
scaler_gen.scale(loss_gen).backward()
scaler_gen.step(opt_gen)
scaler_gen.update()
# Update alpha and ensure less than 1
alpha += cur_batch_size / (
(PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
(config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
)
alpha = min(alpha, 1)
total_time += time.time()-model_start
if batch_idx % 300 == 0:
if batch_idx % 500 == 0:
with torch.no_grad():
fixed_fakes = gen(fixed_noise, alpha, step)
fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
plot_to_tensorboard(
writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
writer,
loss_critic.item(),
loss_gen.item(),
real.detach(),
fixed_fakes.detach(),
tensorboard_step,
)
tensorboard_step += 1
print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
loop.set_postfix(
reals=reals / (batch_idx + 1),
fakes=fakes / (batch_idx + 1),
gp=gp.item(),
loss_critic=loss_critic.item(),
)
return tensorboard_step, alpha
def main():
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, IN_CHANNELS, img_size=IMAGE_SIZE, img_channels=CHANNELS_IMG).to(device)
critic = Discriminator(IMAGE_SIZE, Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)
# but really who cares..
gen = Generator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
# initializate optimizer
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
# initialize optimizers and scalers for FP16 training
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
writer = SummaryWriter(f"logs/gan")
writer = SummaryWriter(f"logs/gan1")
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
)
load_checkpoint(torch.load("celeba_wgan_gp.pth.tar"), gen, critic)
gen.train()
critic.train()
tensorboard_step = 0
for step, num_epochs in enumerate(PROGRESSIVE_EPOCHS):
alpha = 0.01
if step < 3:
continue
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
print(f"Current image size: {4 * 2 ** step}")
if step == 4:
print(f"Img size is: {4*2**step}")
loader, dataset = get_loader(4 * 2 ** step)
for epoch in range(num_epochs):
print(f"Epoch [{epoch+1}/{num_epochs}]")
tensorboard_step, alpha = train_fn(
@@ -152,14 +182,16 @@ def main():
opt_gen,
tensorboard_step,
writer,
scaler_gen,
scaler_critic,
)
checkpoint = {'gen': gen.state_dict(),
'critic': critic.state_dict(),
'opt_gen': opt_gen.state_dict(),
'opt_critic': opt_critic.state_dict()}
if config.SAVE_MODEL:
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
step += 1 # progress to the next img size
save_checkpoint(checkpoint)
if __name__ == "__main__":
main()
main()