mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
update to progan
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# ProGAN
|
||||
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say are on par with the paper; I'll include some example results below.
|
||||
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say are pretty close to the original paper (I'll include some example results below), but because of time limitations I only trained to 256x256 and with a smaller model size than they used in the paper. Increasing the number of channels to 512, instead of the 256 I trained with, would probably make the results even better :)
|
||||
|
||||
## Results
|
||||
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough, as it would take quite a while to train up to 1024x1024.
|
||||
|
||||
21
ML/Pytorch/GANs/ProGAN/config.py
Normal file
21
ML/Pytorch/GANs/ProGAN/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Hyperparameters and runtime configuration for ProGAN training."""

import torch

# Resolution to begin progressive training at; doubles each stage (4 -> 8 -> ...).
START_TRAIN_AT_IMG_SIZE = 4
DATASET = 'celeb_dataset'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = True
LEARNING_RATE = 1e-3
# One batch size per resolution stage: 4x4, 8x8, ..., 1024x1024 (9 stages).
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
CHANNELS_IMG = 3
Z_DIM = 256  # should be 512 in original paper
IN_CHANNELS = 256  # should be 512 in original paper
CRITIC_ITERATIONS = 1
# Gradient-penalty weight (lambda from the WGAN-GP loss).
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
# Fixed latent batch used to visualize generator progress over training.
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 4
|
||||
@@ -1,4 +0,0 @@
|
||||
it = iter(l)
|
||||
|
||||
for el in it:
|
||||
print(el, next(it))
|
||||
@@ -1,50 +1,38 @@
|
||||
""" Training of ProGAN using WGAN-GP loss"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils import (
|
||||
gradient_penalty,
|
||||
plot_to_tensorboard,
|
||||
save_checkpoint,
|
||||
load_checkpoint,
|
||||
generate_examples,
|
||||
)
|
||||
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
|
||||
from model import Discriminator, Generator
|
||||
from math import log2
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import config
|
||||
|
||||
# Enable cuDNN autotuning for fixed-size inputs. The attribute is
# `benchmark` (singular); the original `benchmarks` assignment was a
# silent no-op typo that never enabled autotuning.
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def get_loader(image_size):
    """Build the DataLoader for the current progressive-growing stage.

    Resizes and augments images to `image_size` x `image_size` and picks the
    batch size matching the stage index log2(image_size / 4)
    (4x4 -> index 0, 8x8 -> 1, 16x16 -> 2, ...).

    Returns:
        (loader, dataset): the DataLoader plus the underlying ImageFolder so
        callers can also query len(dataset) (used for the alpha schedule).
    """
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(config.CHANNELS_IMG)],
                [0.5 for _ in range(config.CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
    # Use the configured dataset root (the pasted diff also carried a
    # duplicate construction hard-coding root="celeb_dataset").
    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset
|
||||
|
||||
|
||||
def train_fn(
|
||||
critic,
|
||||
gen,
|
||||
@@ -59,96 +47,91 @@ def train_fn(
|
||||
scaler_gen,
|
||||
scaler_critic,
|
||||
):
|
||||
start = time.time()
|
||||
total_time = 0
|
||||
loop = tqdm(loader, leave=True)
|
||||
# critic_losses = []
|
||||
reals = 0
|
||||
fakes = 0
|
||||
losses_critic = []
|
||||
|
||||
for batch_idx, (real, _) in enumerate(loop):
|
||||
real = real.to(config.DEVICE)
|
||||
cur_batch_size = real.shape[0]
|
||||
model_start = time.time()
|
||||
|
||||
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
|
||||
# which is equivalent to minimizing the negative of the expression
|
||||
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
|
||||
for _ in range(4):
|
||||
# Train Critic: max E[critic(real)] - E[critic(fake)]
|
||||
# which is equivalent to minimizing the negative of the expression
|
||||
for _ in range(config.CRITIC_ITERATIONS):
|
||||
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
fake = gen(noise, alpha, step)
|
||||
critic_real = critic(real, alpha, step)
|
||||
critic_fake = critic(fake.detach(), alpha, step)
|
||||
reals += critic_real.mean().item()
|
||||
fakes += critic_fake.mean().item()
|
||||
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
|
||||
loss_critic = (
|
||||
-(torch.mean(critic_real) - torch.mean(critic_fake))
|
||||
+ config.LAMBDA_GP * gp
|
||||
+ (0.001 * torch.mean(critic_real ** 2))
|
||||
)
|
||||
with torch.cuda.amp.autocast():
|
||||
fake = gen(noise, alpha, step)
|
||||
critic_real = critic(real, alpha, step).reshape(-1)
|
||||
critic_fake = critic(fake, alpha, step).reshape(-1)
|
||||
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
|
||||
loss_critic = (
|
||||
-(torch.mean(critic_real) - torch.mean(critic_fake))
|
||||
+ config.LAMBDA_GP * gp
|
||||
)
|
||||
|
||||
opt_critic.zero_grad()
|
||||
scaler_critic.scale(loss_critic).backward()
|
||||
scaler_critic.step(opt_critic)
|
||||
scaler_critic.update()
|
||||
losses_critic.append(loss_critic.item())
|
||||
opt_critic.zero_grad()
|
||||
scaler_critic.scale(loss_critic).backward()
|
||||
scaler_critic.step(opt_critic)
|
||||
scaler_critic.update()
|
||||
#loss_critic.backward(retain_graph=True)
|
||||
#opt_critic.step()
|
||||
|
||||
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
|
||||
with torch.cuda.amp.autocast():
|
||||
gen_fake = critic(fake, alpha, step)
|
||||
loss_gen = -torch.mean(gen_fake)
|
||||
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
|
||||
with torch.cuda.amp.autocast():
|
||||
fake = gen(noise, alpha, step)
|
||||
gen_fake = critic(fake, alpha, step).reshape(-1)
|
||||
loss_gen = -torch.mean(gen_fake)
|
||||
|
||||
opt_gen.zero_grad()
|
||||
scaler_gen.scale(loss_gen).backward()
|
||||
scaler_gen.step(opt_gen)
|
||||
scaler_gen.update()
|
||||
opt_gen.zero_grad()
|
||||
scaler_gen.scale(loss_gen).backward()
|
||||
scaler_gen.step(opt_gen)
|
||||
scaler_gen.update()
|
||||
#gen.zero_grad()
|
||||
#loss_gen.backward()
|
||||
#opt_gen.step()
|
||||
|
||||
# Update alpha and ensure less than 1
|
||||
alpha += cur_batch_size / (
|
||||
(config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
|
||||
(config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
|
||||
)
|
||||
alpha = min(alpha, 1)
|
||||
total_time += time.time()-model_start
|
||||
|
||||
if batch_idx % 500 == 0:
|
||||
if batch_idx % 10 == 0:
|
||||
print(alpha)
|
||||
with torch.no_grad():
|
||||
fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
|
||||
fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
|
||||
plot_to_tensorboard(
|
||||
writer,
|
||||
loss_critic.item(),
|
||||
loss_gen.item(),
|
||||
real.detach(),
|
||||
fixed_fakes.detach(),
|
||||
tensorboard_step,
|
||||
writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
|
||||
)
|
||||
tensorboard_step += 1
|
||||
|
||||
loop.set_postfix(
|
||||
reals=reals / (batch_idx + 1),
|
||||
fakes=fakes / (batch_idx + 1),
|
||||
gp=gp.item(),
|
||||
loss_critic=loss_critic.item(),
|
||||
)
|
||||
mean_loss = sum(losses_critic) / len(losses_critic)
|
||||
loop.set_postfix(loss=mean_loss)
|
||||
|
||||
print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
|
||||
return tensorboard_step, alpha
|
||||
|
||||
|
||||
def main():
|
||||
# initialize gen and disc, note: discriminator should be called critic,
|
||||
# according to WGAN paper (since it no longer outputs between [0, 1])
|
||||
# but really who cares..
|
||||
gen = Generator(
|
||||
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
).to(config.DEVICE)
|
||||
critic = Discriminator(
|
||||
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
).to(config.DEVICE)
|
||||
gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
|
||||
critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
|
||||
|
||||
# initialize optimizers and scalers for FP16 training
|
||||
# initializate optimizer
|
||||
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
|
||||
opt_critic = optim.Adam(
|
||||
critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
|
||||
)
|
||||
opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
|
||||
scaler_critic = torch.cuda.amp.GradScaler()
|
||||
scaler_gen = torch.cuda.amp.GradScaler()
|
||||
|
||||
# for tensorboard plotting
|
||||
writer = SummaryWriter(f"logs/gan1")
|
||||
writer = SummaryWriter(f"logs/gan")
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
@@ -162,13 +145,12 @@ def main():
|
||||
critic.train()
|
||||
|
||||
tensorboard_step = 0
|
||||
# start at step that corresponds to img size that we set in config
|
||||
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
|
||||
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
|
||||
alpha = 1e-5 # start with very low alpha
|
||||
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
|
||||
print(f"Current image size: {4 * 2 ** step}")
|
||||
step = int(log2(config.START_TRAIN_AT_IMG_SIZE/4))
|
||||
|
||||
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
|
||||
alpha = 0.01
|
||||
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3
|
||||
print(f"Current image size: {4*2**step}")
|
||||
for epoch in range(num_epochs):
|
||||
print(f"Epoch [{epoch+1}/{num_epochs}]")
|
||||
tensorboard_step, alpha = train_fn(
|
||||
@@ -190,8 +172,7 @@ def main():
|
||||
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
|
||||
save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
|
||||
|
||||
step += 1 # progress to the next img size
|
||||
|
||||
step += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
75
ML/Pytorch/GANs/ProGAN/utils.py
Normal file
75
ML/Pytorch/GANs/ProGAN/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import torchvision
|
||||
import torch.nn as nn
|
||||
|
||||
# Print losses occasionally and print to tensorboard
|
||||
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    """Log scalar losses and example real/fake image grids to TensorBoard.

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance.
        loss_critic: scalar critic (discriminator) loss for this step.
        loss_gen: scalar generator loss for this step.
        real: batch of real images, shape (N, C, H, W).
        fake: batch of generated images, shape (N, C, H, W).
        tensorboard_step: global step used to index all logged values.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # Fix: loss_gen was accepted but never logged.
    writer.add_scalar("Loss Gen", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # Take out (up to) 8 examples for the image grids.
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
|
||||
|
||||
|
||||
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """WGAN-GP penalty: E[(||grad critic(x_interp)||_2 - 1)^2].

    Interpolates randomly between real and (detached) fake images, scores the
    interpolants with the critic, and penalizes the squared deviation of the
    per-example gradient norm from 1.
    """
    batch_size, channels, height, width = real.shape
    # One random mixing coefficient per example, replicated across C, H, W.
    eps = torch.rand((batch_size, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    mixed = real * eps + fake.detach() * (1 - eps)
    mixed.requires_grad_(True)

    # Critic scores on the interpolated images.
    scores = critic(mixed, alpha, train_step)

    # Gradient of the scores with respect to the interpolated images.
    grads = torch.autograd.grad(
        inputs=mixed,
        outputs=scores,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0].reshape(batch_size, -1)
    norms = grads.norm(2, dim=1)
    return ((norms - 1) ** 2).mean()
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state dicts to `filename`."""
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model/optimizer state from a checkpoint and reset the lr.

    Args:
        checkpoint_file: path to a checkpoint written by save_checkpoint.
        model: module whose parameters are restored in place.
        optimizer: optimizer whose state is restored in place.
        lr: learning rate forced onto every param group after loading.
    """
    print("=> Loading checkpoint")
    # Fix: map_location was hard-coded to "cuda", which crashes on CPU-only
    # machines; map to whichever device is actually available.
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have the learning rate of the old
    # checkpoint and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
|
||||
|
||||
def seed_everything(seed=42):
    """Seed every RNG (hash, Python, NumPy, Torch CPU/CUDA) for reproducibility.

    Also forces cuDNN into deterministic mode (disabling autotuning) so
    repeated runs with the same seed produce identical results.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    # Trade autotuning speed for run-to-run determinism in cuDNN.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user