"""Adversarial training script for the low-res -> high-res image generator.

The generator maps each low-resolution batch to a ``fake`` image that is
compared against the matching high-resolution batch with a VGG-based content
loss, while a discriminator is trained to separate real high-resolution
images from generator output.
"""

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import config
from dataset import MyImageFolder
from loss import VGGLoss
from model import Generator, Discriminator
from utils import load_checkpoint, save_checkpoint, plot_examples

# Let cuDNN benchmark and cache convolution algorithms; this pays off when
# input shapes stay the same from one iteration to the next.
torch.backends.cudnn.benchmark = True


def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    mse,
    bce,
    vgg_loss,
    *,
    adv_weight=1e-3,
    vgg_weight=0.006,
    l2_weight=0.0,
    label_smoothing=0.1,
    plot_every=200,
):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        loader: Iterable yielding ``(low_res, high_res)`` batch pairs.
        disc: Discriminator module; its raw output is passed to ``bce``.
        gen: Generator module mapping ``low_res`` to a fake high-res batch.
        opt_gen: Optimizer over the generator's parameters.
        opt_disc: Optimizer over the discriminator's parameters.
        mse: Pixel-wise loss, used only when ``l2_weight`` is non-zero.
        bce: Adversarial loss called as ``bce(disc_output, target)``.
            ``main`` passes ``nn.BCEWithLogitsLoss``, so ``disc`` is
            presumably expected to return logits.
        vgg_loss: Content loss called as ``vgg_loss(fake, high_res)``.
        adv_weight: Weight of the generator's adversarial term.
        vgg_weight: Weight of the VGG content term.
        l2_weight: Weight of an optional ``mse(fake, high_res)`` term added
            to the generator loss; 0 (the default) skips it entirely.
        label_smoothing: Upper bound on the random amount subtracted from
            the "real" target of 1 (one-sided label smoothing).
        plot_every: Call ``plot_examples`` every this many batches, starting
            with the first; 0 or a negative value disables plotting.
    """
    loop = tqdm(loader, leave=True)
    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)
        fake = gen(low_res)

        # --- Discriminator: max log(D(x)) + log(1 - D(G(z))) ---
        # ``fake`` is detached so this step does not backpropagate into gen.
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())
        # One-sided label smoothing: each real target is
        # 1 - label_smoothing * U[0, 1), i.e. in (1 - label_smoothing, 1];
        # fake targets stay at exactly 0.
        real_target = (
            torch.ones_like(disc_real)
            - label_smoothing * torch.rand_like(disc_real)
        )
        disc_loss_real = bce(disc_real, real_target)
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # --- Generator: min log(1 - D(G(z))) <-> max log(D(G(z))) ---
        # Re-run the (just-updated) discriminator on the non-detached fake so
        # gradients reach the generator. This backward also accumulates
        # gradients in disc's parameters; they are cleared by
        # opt_disc.zero_grad() before the next discriminator step.
        disc_fake = disc(fake)
        adversarial_loss = adv_weight * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = vgg_weight * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss
        if l2_weight:
            gen_loss = gen_loss + l2_weight * mse(fake, high_res)

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if plot_every > 0 and idx % plot_every == 0:
            # utils helper; presumably writes generator samples under
            # "test_images/" — confirm in utils.py.
            plot_examples("test_images/", gen)


def main():
    """Build data, models and optimizers, then train for ``config.NUM_EPOCHS``.

    When ``config.LOAD_MODEL`` is set, ``load_checkpoint`` is called for both
    networks (from ``config.CHECKPOINT_GEN`` / ``config.CHECKPOINT_DISC``)
    before training. When ``config.SAVE_MODEL`` is set, both networks and
    their optimizers are saved via ``save_checkpoint`` after every epoch.
    """
    dataset = MyImageFolder(root_dir="new_data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(img_channels=3).to(config.DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    # NOTE(review): VGGLoss is not moved to config.DEVICE here; presumably it
    # handles device placement itself — confirm in loss.py.
    vgg_loss = VGGLoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN,
            gen,
            opt_gen,
            config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC,
            disc,
            opt_disc,
            config.LEARNING_RATE,
        )

    for epoch in range(config.NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)

        if config.SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)


if __name__ == "__main__":
    main()