damn, copied over wrong train file for ProGAN (will check this more thoroughly before the video is up too)

Aladdin Persson
2021-03-19 20:21:14 +01:00
parent bd6db84daa
commit c72d1d6a31


@@ -1,38 +1,49 @@
 """ Training of ProGAN using WGAN-GP loss"""
 import torch
-import torch.nn as nn
 import torch.optim as optim
-import torchvision
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
+from utils import (
+    gradient_penalty,
+    plot_to_tensorboard,
+    save_checkpoint,
+    load_checkpoint,
+)
 from model import Discriminator, Generator
 from math import log2
 from tqdm import tqdm
-import time
 import config

 torch.backends.cudnn.benchmarks = True


 def get_loader(image_size):
     transform = transforms.Compose(
         [
             transforms.Resize((image_size, image_size)),
             transforms.ToTensor(),
+            transforms.RandomHorizontalFlip(p=0.5),
             transforms.Normalize(
                 [0.5 for _ in range(config.CHANNELS_IMG)],
                 [0.5 for _ in range(config.CHANNELS_IMG)],
             ),
         ]
     )
-    batch_size = config.BATCH_SIZES[int(log2(image_size/4))]
-    dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
-    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
+    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
+    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
+    loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=config.NUM_WORKERS,
+        pin_memory=True,
+    )
     return loader, dataset


 def train_fn(
     critic,
     gen,
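Two notes on this hunk. torch.backends.cudnn.benchmarks = True (unchanged on both sides) is a silent no-op: the actual flag is torch.backends.cudnn.benchmark, without the trailing s, so cuDNN autotuning is never really enabled. More usefully, int(log2(image_size / 4)) maps each resolution of the progressive-growing schedule to an index into config.BATCH_SIZES, so that coarse resolutions can train with larger batches. A minimal sketch of that mapping, with invented batch sizes (the real list lives in config.py, which this commit does not touch):

from math import log2

# Hypothetical per-resolution batch sizes; config.py defines the real list.
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]

for image_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    step = int(log2(image_size / 4))  # 4 -> 0, 8 -> 1, ..., 1024 -> 8
    print(f"{image_size:4d}px -> step {step} -> batch size {BATCH_SIZES[step]}")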
@@ -47,91 +58,96 @@ def train_fn(
     scaler_gen,
     scaler_critic,
 ):
-    start = time.time()
-    total_time = 0
     loop = tqdm(loader, leave=True)
-    losses_critic = []
+    # critic_losses = []
+    reals = 0
+    fakes = 0
     for batch_idx, (real, _) in enumerate(loop):
         real = real.to(config.DEVICE)
         cur_batch_size = real.shape[0]
-        model_start = time.time()
-        for _ in range(4):
-            # Train Critic: max E[critic(real)] - E[critic(fake)]
-            # which is equivalent to minimizing the negative of the expression
-            noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
+        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
+        # which is equivalent to minimizing the negative of the expression
+        noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
+        for _ in range(config.CRITIC_ITERATIONS):
             with torch.cuda.amp.autocast():
                 fake = gen(noise, alpha, step)
-                critic_real = critic(real, alpha, step).reshape(-1)
-                critic_fake = critic(fake, alpha, step).reshape(-1)
-                gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
+                critic_real = critic(real, alpha, step)
+                critic_fake = critic(fake.detach(), alpha, step)
+                reals += critic_real.mean().item()
+                fakes += critic_fake.mean().item()
+                gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
                 loss_critic = (
                     -(torch.mean(critic_real) - torch.mean(critic_fake))
                     + config.LAMBDA_GP * gp
+                    + (0.001 * torch.mean(critic_real ** 2))
                 )

-            losses_critic.append(loss_critic.item())
             opt_critic.zero_grad()
             scaler_critic.scale(loss_critic).backward()
             scaler_critic.step(opt_critic)
             scaler_critic.update()
-            #loss_critic.backward(retain_graph=True)
-            #opt_critic.step()

         # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
         with torch.cuda.amp.autocast():
-            fake = gen(noise, alpha, step)
-            gen_fake = critic(fake, alpha, step).reshape(-1)
+            gen_fake = critic(fake, alpha, step)
             loss_gen = -torch.mean(gen_fake)

         opt_gen.zero_grad()
         scaler_gen.scale(loss_gen).backward()
         scaler_gen.step(opt_gen)
         scaler_gen.update()
-        #gen.zero_grad()
-        #loss_gen.backward()
-        #opt_gen.step()

         # Update alpha and ensure less than 1
         alpha += cur_batch_size / (
-            (config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
+            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
         )
         alpha = min(alpha, 1)
-        total_time += time.time()-model_start

-        if batch_idx % 10 == 0:
-            print(alpha)
+        if batch_idx % 500 == 0:
             with torch.no_grad():
-                fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
+                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
             plot_to_tensorboard(
-                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
+                writer,
+                loss_critic.item(),
+                loss_gen.item(),
+                real.detach(),
+                fixed_fakes.detach(),
+                tensorboard_step,
             )
             tensorboard_step += 1

-        mean_loss = sum(losses_critic) / len(losses_critic)
-        loop.set_postfix(loss=mean_loss)
+        loop.set_postfix(
+            reals=reals / (batch_idx + 1),
+            fakes=fakes / (batch_idx + 1),
+            gp=gp.item(),
+            loss_critic=loss_critic.item(),
+        )

-    print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
     return tensorboard_step, alpha


 def main():
     # initialize gen and disc, note: discriminator should be called critic,
     # according to WGAN paper (since it no longer outputs between [0, 1])
-    gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
-    critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
+    # but really who cares..
+    gen = Generator(
+        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+    ).to(config.DEVICE)
+    critic = Discriminator(
+        config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+    ).to(config.DEVICE)

-    # initializate optimizer
+    # initialize optimizers and scalers for FP16 training
     opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
-    opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
+    opt_critic = optim.Adam(
+        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
+    )
     scaler_critic = torch.cuda.amp.GradScaler()
     scaler_gen = torch.cuda.amp.GradScaler()

     # for tensorboard plotting
-    writer = SummaryWriter(f"logs/gan")
+    writer = SummaryWriter(f"logs/gan1")

     if config.LOAD_MODEL:
         load_checkpoint(
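gradient_penalty is imported from utils, which this commit does not touch, so the helper itself is not visible here. For orientation, below is a minimal sketch of the standard WGAN-GP penalty such a helper typically computes: the mean of (||grad||_2 - 1)^2, with the gradient of the critic score taken at random interpolates between real and fake images. It follows the old call signature, which threads alpha and step through to the critic; treat it as an illustration, not the repo's actual utils code.

import torch

def gradient_penalty(critic, real, fake, alpha, step, device="cpu"):
    batch_size, c, h, w = real.shape
    # Per-sample mixing factor (named beta so it does not shadow the fade-in alpha)
    beta = torch.rand((batch_size, 1, 1, 1), device=device)
    interpolated = real * beta + fake.detach() * (1 - beta)
    interpolated.requires_grad_(True)

    # Critic score on the interpolated images
    mixed_scores = critic(interpolated, alpha, step)

    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)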
@@ -145,12 +161,13 @@ def main():
     critic.train()

     tensorboard_step = 0
-    step = int(log2(config.START_TRAIN_AT_IMG_SIZE/4))
+    # start at step that corresponds to img size that we set in config
+    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
     for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
-        alpha = 0.01
-        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3
-        print(f"Current image size: {4*2**step}")
+        alpha = 1e-5  # start with very low alpha
+        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
+        print(f"Current image size: {4 * 2 ** step}")
         for epoch in range(num_epochs):
             print(f"Epoch [{epoch+1}/{num_epochs}]")
             tensorboard_step, alpha = train_fn(
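The alpha = 1e-5 initialization here pairs with the update inside train_fn above: every batch adds cur_batch_size / ((PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)), so the fade-in factor grows linearly with the number of images seen and saturates at 1 exactly halfway through the epochs scheduled for that resolution, after which min(alpha, 1) keeps it clamped. A quick check with invented numbers (none of these come from config.py):

num_epochs, dataset_size, batch_size = 30, 30_000, 16  # hypothetical schedule

alpha = 1e-5
per_batch = batch_size / ((num_epochs * 0.5) * dataset_size)
for _ in range(num_epochs * (dataset_size // batch_size)):
    alpha = min(alpha + per_batch, 1)

print(alpha)  # 1.0, already reached after epoch 15 of 30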
@@ -172,7 +189,8 @@ def main():
         save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
         save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

-        step += 1
+        step += 1  # progress to the next img size


 if __name__ == "__main__":
     main()
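The script leans throughout on a config module that is not part of this commit. As a placeholder, here is a sketch listing the attributes the code above reads; every value is invented purely for illustration, only the names come from the script:

import torch

# Hypothetical config.py; all values are placeholders, not the repo's settings.
START_TRAIN_AT_IMG_SIZE = 4
DATASET = "celeb_dataset"
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LOAD_MODEL = False
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
CHANNELS_IMG = 3
Z_DIM = 256
W_DIM = 256
IN_CHANNELS = 256
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
NUM_WORKERS = 4
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(8, Z_DIM).to(DEVICE)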