Files
Machine-Learning-Collection/ML/Pytorch/GANs/ProGAN/train.py
2021-03-06 11:01:51 +01:00

165 lines
5.4 KiB
Python

""" Training of ProGAN using WGAN-GP loss"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import time
# Let cuDNN autotune its convolution algorithms. The flag is `benchmark`
# (singular); assigning to a misspelled attribute is silently ignored.
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)

# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
# Batch size per resolution; get_loader indexes this with
# int(log2(image_size / 4)), so entry 0 is used for 4x4 images.
BATCH_SIZES = [128, 128, 64, 16, 8, 4, 2, 2, 1]
IMAGE_SIZE = 128
CHANNELS_IMG = 3
Z_DIM = 128  # channel count of the noise tensor fed to the generator
IN_CHANNELS = 128
CRITIC_ITERATIONS = 1  # critic updates per generator update
LAMBDA_GP = 10  # multiplier on the gradient-penalty term of the critic loss
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1
# Number of epochs to train at each progressive step (4x4 up to IMAGE_SIZE).
PROGRESSIVE_EPOCHS = [8 for _ in range(int(log2(IMAGE_SIZE / 4) + 1))]
# Fixed latent batch, reused so logged samples are comparable over time.
fixed_noise = torch.randn(8, Z_DIM, 1, 1).to(device)
NUM_WORKERS = 4  # DataLoader worker processes
def get_loader(image_size):
    """Create the training DataLoader for square images of ``image_size``.

    Images are read from the ``celeb_dataset`` folder, resized to
    ``image_size`` x ``image_size``, converted to tensors and normalized
    with mean 0.5 and std 0.5 on each of the ``CHANNELS_IMG`` channels.
    The batch size is looked up in ``BATCH_SIZES`` by resolution level.

    Args:
        image_size: Side length, in pixels, of the images to produce.

    Returns:
        A ``(loader, dataset)`` tuple: the shuffled DataLoader and the
        ImageFolder dataset it draws from.
    """
    channel_mean = [0.5] * CHANNELS_IMG
    channel_std = [0.5] * CHANNELS_IMG
    pipeline = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    # Resolution level: 0 for 4x4, 1 for 8x8, and so on.
    level = int(log2(image_size / 4))
    images = datasets.ImageFolder(root="celeb_dataset", transform=pipeline)
    batches = DataLoader(
        images,
        batch_size=BATCH_SIZES[level],
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return batches, images
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
):
    """Train critic and generator for one pass over ``loader`` (WGAN-GP loss).

    Args:
        critic: Critic network, called as ``critic(images, alpha, step)``.
        gen: Generator network, called as ``gen(noise, alpha, step)``.
        loader: Iterable of ``(real_images, labels)`` batches; labels are
            ignored.
        dataset: Dataset behind ``loader``; only ``len(dataset)`` is used,
            to scale the per-batch alpha increment.
        step: Current progressive step; forwarded to both networks and used
            to index ``PROGRESSIVE_EPOCHS``.
        alpha: Value forwarded to both networks (presumably the ProGAN
            fade-in weight -- confirm against model.py). It grows every
            batch and is capped at 1.
        opt_critic: Optimizer for the critic parameters.
        opt_gen: Optimizer for the generator parameters.
        tensorboard_step: Running counter passed to ``plot_to_tensorboard``.
        writer: TensorBoard writer passed to ``plot_to_tensorboard``.

    Returns:
        ``(tensorboard_step, alpha)`` after the pass, to be fed into the
        next call.
    """
    start = time.time()
    total_time = 0
    training = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(training):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step).reshape(-1)
            # Detach so the critic loss does not backpropagate into the
            # generator; those gradients would be zeroed before the
            # generator step anyway.
            critic_fake = critic(fake.detach(), alpha, step).reshape(-1)
            # `fake` is passed as-is: gradient_penalty is defined elsewhere
            # and may rely on its autograd state.
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
            )
            # No retain_graph: the generator step below runs a fresh forward
            # pass, so this graph is never backpropagated a second time.
            loss_critic.backward()
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen.zero_grad()
        fake = gen(noise, alpha, step)
        gen_fake = critic(fake, alpha, step).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        loss_gen.backward()
        opt_gen.step()

        # Update alpha and ensure less than 1. The denominator is half of
        # the samples seen over PROGRESSIVE_EPOCHS[step] epochs, so alpha
        # reaches 1 about halfway through the step.
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)
        total_time += time.time() - model_start

        if batch_idx % 300 == 0:
            # Sample from the fixed noise so logged images are comparable.
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step)
            plot_to_tensorboard(
                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
            )
            tensorboard_step += 1

    print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
    return tensorboard_step, alpha
def main():
    """Build the models, restore a checkpoint and run progressive training.

    Creates the generator, critic, their Adam optimizers and a TensorBoard
    writer, loads weights from ``celeba_wgan_gp.pth.tar``, then trains each
    progressive step for its configured number of epochs, saving a
    checkpoint after every epoch.
    """
    # initialize gen and disc, note: discriminator should be called critic,
    # according to WGAN paper (since it no longer outputs between [0, 1])
    gen = Generator(
        Z_DIM, IN_CHANNELS, img_size=IMAGE_SIZE, img_channels=CHANNELS_IMG
    ).to(device)
    critic = Discriminator(
        IMAGE_SIZE, Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(device)

    # initialize optimizers
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(
        critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
    )

    # for tensorboard plotting
    writer = SummaryWriter("logs/gan")

    # map_location lets a checkpoint saved on GPU be restored when `device`
    # has fallen back to CPU (and vice versa).
    load_checkpoint(
        torch.load("celeba_wgan_gp.pth.tar", map_location=device), gen, critic
    )
    gen.train()
    critic.train()

    tensorboard_step = 0
    for step, num_epochs in enumerate(PROGRESSIVE_EPOCHS):
        # NOTE(review): steps 0-2 are skipped, presumably because the loaded
        # checkpoint already covers them -- confirm before reusing.
        if step < 3:
            continue

        alpha = 0.01  # each step restarts alpha from a small value

        # NOTE(review): the image size is only printed for step 4; this
        # looks like a debugging leftover -- confirm intent.
        if step == 4:
            print(f"Img size is: {4*2**step}")

        loader, dataset = get_loader(4 * 2 ** step)
        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
            )

            checkpoint = {
                'gen': gen.state_dict(),
                'critic': critic.state_dict(),
                'opt_gen': opt_gen.state_dict(),
                'opt_critic': opt_critic.state_dict(),
            }
            save_checkpoint(checkpoint)


if __name__ == "__main__":
    main()