Files
Machine-Learning-Collection/ML/Pytorch/GANs/ESRGAN/train.py
2021-05-15 14:58:41 +02:00

155 lines
4.2 KiB
Python

import torch
import config
from torch import nn
from torch import optim
from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator, initialize_weights
from tqdm import tqdm
from dataset import MyImageFolder
from torch.utils.tensorboard import SummaryWriter
# Let cuDNN auto-tune convolution algorithms; fastest when input sizes are fixed.
torch.backends.cudnn.benchmark = True
def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    """Run one epoch of ESRGAN training.

    The critic is trained with a WGAN-GP objective; the generator with a
    weighted sum of L1 pixel loss, VGG perceptual loss, and the WGAN
    adversarial term. Both updates run under CUDA AMP autocast with a
    separate GradScaler per optimizer.

    Args:
        loader: DataLoader yielding (low_res, high_res) image batch pairs.
        disc: critic/discriminator network.
        gen: generator network.
        opt_gen: optimizer for the generator.
        opt_disc: optimizer for the critic.
        l1: nn.L1Loss instance for pixel-wise loss.
        vgg_loss: perceptual (VGG-feature) loss module.
        g_scaler: GradScaler for the generator optimizer.
        d_scaler: GradScaler for the critic optimizer.
        writer: TensorBoard SummaryWriter for scalar logging.
        tb_step: running global step counter for TensorBoard.

    Returns:
        The updated tb_step after consuming the whole loader.
    """
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        # --- Critic update: maximize E[D(real)] - E[D(fake)], i.e. minimize
        # the negated difference, plus the gradient penalty term. ---
        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            # detach() so the critic's backward pass does not reach the generator.
            critic_fake = disc(fake.detach())
            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
            )

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # --- Generator update: L1 (weight 1e-2) + VGG perceptual +
        # adversarial -E[D(fake)] (weight 5e-3). Reuses `fake`, whose graph
        # back to the generator is still intact (only critic_fake detached). ---
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
        tb_step += 1

        # Every 100 batches (skipping batch 0), dump super-resolved samples
        # from test_images/. NOTE(review): presumably plot_examples handles
        # eval/train mode internally — verify in utils.
        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=loss_critic.item(),
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step
def main():
    """Build the dataset, networks, losses and AMP scalers, then train.

    Optionally resumes both networks from checkpoints (config.LOAD_MODEL)
    and saves them after every epoch (config.SAVE_MODEL).
    """
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )

    # Networks — only the generator receives explicit weight initialization.
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    initialize_weights(gen)

    # WGAN-style Adam betas (0.0, 0.9) for both optimizers.
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))

    writer = SummaryWriter("logs")
    tb_step = 0
    l1 = nn.L1Loss()
    gen.train()
    disc.train()
    vgg_loss = VGGLoss()

    # One GradScaler per optimizer, matching the two separate backward passes.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if config.LOAD_MODEL:
        for ckpt_file, model, opt in (
            (config.CHECKPOINT_GEN, gen, opt_gen),
            (config.CHECKPOINT_DISC, disc, opt_disc),
        ):
            load_checkpoint(ckpt_file, model, opt, config.LEARNING_RATE)

    for _epoch in range(config.NUM_EPOCHS):
        tb_step = train_fn(
            loader,
            disc,
            gen,
            opt_gen,
            opt_disc,
            l1,
            vgg_loss,
            g_scaler,
            d_scaler,
            writer,
            tb_step,
        )

        if config.SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
if __name__ == "__main__":
    # Toggle: True -> inference-only demo with pretrained weights,
    #         False -> full training run from scratch.
    try_model = True

    if not try_model:
        main()
    else:
        # Load a pretrained generator, super-resolve every image found in
        # test_images/ and write the results into saved/.
        gen = Generator(in_channels=3).to(config.DEVICE)
        opt_gen = optim.Adam(
            gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9)
        )
        load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
        plot_examples("test_images/", gen)