damn, copied over the wrong train file for ProGAN (will check this more thoroughly before the video is up, too)

Aladdin Persson
2021-03-19 20:21:14 +01:00
parent bd6db84daa
commit c72d1d6a31

""" Training of ProGAN using WGAN-GP loss"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import (
    gradient_penalty,
    plot_to_tensorboard,
    save_checkpoint,
    load_checkpoint,
)
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import time
import config
torch.backends.cudnn.benchmark = True


def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(config.CHANNELS_IMG)],
                [0.5 for _ in range(config.CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset
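
# Note: BATCH_SIZES is indexed by the resolution step, int(log2(image_size / 4)):
# image_size 4 -> index 0, 8 -> 1, 16 -> 2, ..., 128 -> 5. With a hypothetical
# config such as BATCH_SIZES = [32, 32, 32, 16, 16, 16] (illustrative values,
# not necessarily the author's), higher resolutions automatically train with
# smaller batches to fit in GPU memory.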


def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    start = time.time()
    total_time = 0
    loop = tqdm(loader, leave=True)
    reals = 0
    fakes = 0

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        for _ in range(config.CRITIC_ITERATIONS):
            # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
            # which is equivalent to minimizing the negative of the expression
            noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
            with torch.cuda.amp.autocast():
                fake = gen(noise, alpha, step)
                critic_real = critic(real, alpha, step)
                critic_fake = critic(fake.detach(), alpha, step)
                reals += critic_real.mean().item()
                fakes += critic_fake.mean().item()
                gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
                # the last term is ProGAN's small drift penalty, which keeps
                # the critic's scores on real images from drifting far from zero
                loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake))
                    + config.LAMBDA_GP * gp
                    + (0.001 * torch.mean(critic_real ** 2))
                )
            opt_critic.zero_grad()
            scaler_critic.scale(loss_critic).backward()
            scaler_critic.step(opt_critic)
            scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)
        total_time += time.time() - model_start

        if batch_idx % 10 == 0:
            print(alpha)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                # map images from [-1, 1] back to [0, 1] for display
                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            reals=reals / (batch_idx + 1),
            fakes=fakes / (batch_idx + 1),
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    print(f"Fraction spent on model training: {total_time / (time.time() - start)}")
    return tensorboard_step, alpha
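
The gradient_penalty helper imported from utils is not shown in this commit. Below
is a minimal sketch of what it presumably implements (the standard WGAN-GP penalty
of Gulrajani et al.); note that the call above passes no alpha/step while a
ProGAN-style critic needs them, so the signature here is an assumption, and the
mismatch may be part of what the commit message is apologizing for.

    # Hedged sketch, not the repository's actual utils.gradient_penalty.
    def gradient_penalty_sketch(critic, real, fake, alpha, train_step, device="cpu"):
        batch_size, c, h, w = real.shape
        # interpolate real and fake images with a random per-sample weight
        beta = torch.rand((batch_size, 1, 1, 1), device=device).repeat(1, c, h, w)
        interpolated = real * beta + fake.detach() * (1 - beta)
        interpolated.requires_grad_(True)

        # critic scores on the interpolated images
        mixed_scores = critic(interpolated, alpha, train_step)

        # gradient of the scores w.r.t. the interpolated images
        gradient = torch.autograd.grad(
            inputs=interpolated,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]

        # penalize deviation of the per-sample gradient norm from 1
        gradient = gradient.view(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        return torch.mean((gradient_norm - 1) ** 2)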

def main():
    # initialize gen and disc, note: discriminator should be called critic,
    # according to WGAN paper (since it no longer outputs between [0, 1])
    # but really who cares..
    gen = Generator(
        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    critic = Discriminator(
        config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)

    # initialize optimizers and scalers for FP16 training
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(
        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
    )
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    # for tensorboard plotting
    writer = SummaryWriter("logs/gan1")

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
        )

    gen.train()
    critic.train()

    tensorboard_step = 0
    # start at step that corresponds to img size that we set in config
    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64->4
        print(f"Current image size: {4 * 2 ** step}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch + 1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
            )

            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

        step += 1  # progress to the next img size

if __name__ == "__main__":
    main()
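
For intuition, the alpha update in train_fn fades each new resolution in over the
first half of a stage's epochs: the per-batch increment is
cur_batch_size / (0.5 * PROGRESSIVE_EPOCHS[step] * len(dataset)). A back-of-envelope
check with assumed numbers (illustrative only, not this repo's config):

    # assumed, illustrative values -- not the author's actual config
    epochs_at_stage = 30
    dataset_size = 30_000
    batch_size = 16

    increment = batch_size / (0.5 * epochs_at_stage * dataset_size)  # per batch
    batches_to_full_alpha = 1 / increment                            # 28_125 batches
    epochs_to_full_alpha = batches_to_full_alpha / (dataset_size / batch_size)
    print(epochs_to_full_alpha)  # 15.0 -> alpha reaches 1 halfway through the stage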