# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# (synced 2026-02-21 19:27:58 +00:00; 155 lines, 4.2 KiB, Python)
import torch
|
|
import config
|
|
from torch import nn
|
|
from torch import optim
|
|
from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
|
|
from loss import VGGLoss
|
|
from torch.utils.data import DataLoader
|
|
from model import Generator, Discriminator, initialize_weights
|
|
from tqdm import tqdm
|
|
from dataset import MyImageFolder
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
# Enable cuDNN autotuning: per the PyTorch docs, cuDNN benchmarks candidate
# convolution algorithms and caches the fastest one for each input shape.
# This pays off when batch/image sizes stay constant across iterations.
torch.backends.cudnn.benchmark = True
|
|
|
|
def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    """Run one epoch of WGAN-GP super-resolution training.

    For every ``(low_res, high_res)`` batch yielded by ``loader`` this does one
    critic (discriminator) update followed by one generator update, each under
    ``torch.cuda.amp.autocast`` with its own gradient scaler.

    Args:
        loader: Iterable yielding ``(low_res, high_res)`` batch pairs.
        disc: Critic network; the critic loss below pushes its scores up on
            real images and down on generated ones.
        gen: Generator network, called as ``gen(low_res)``.
        opt_gen: Optimizer stepping the generator.
        opt_disc: Optimizer stepping the critic.
        l1: Pixel loss module, called as ``l1(fake, high_res)``.
        vgg_loss: Loss module, called as ``vgg_loss(fake, high_res)``.
        g_scaler: AMP ``GradScaler`` used for the generator step.
        d_scaler: AMP ``GradScaler`` used for the critic step.
        writer: TensorBoard writer; receives the critic loss every batch.
        tb_step: TensorBoard global step to start logging at.

    Returns:
        The next unused TensorBoard step, i.e. ``tb_step`` plus the number of
        batches processed.
    """
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        # ---- Critic step: minimize E[D(fake)] - E[D(real)] + LAMBDA_GP * GP ----
        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            # Detached: the critic's score on fakes must not backprop into gen.
            critic_fake = disc(fake.detach())
            # NOTE(review): the non-detached `fake` is passed here, yet the
            # generator graph is reused below, so gradient_penalty presumably
            # detaches it internally -- verify in utils.
            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
            )

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step: WGAN objective, i.e. minimize -E[D(G(z))] ----
        # (maximize the critic's score on fakes), added to weighted L1 and VGG
        # terms. This is NOT the log(1 - D(G(z))) loss of the original GAN.
        # Backprop through disc(fake) also fills .grad on the critic's
        # parameters; opt_disc.zero_grad() clears them before the next critic
        # step, so they never affect a critic update.
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Read the critic loss back once and reuse it for TensorBoard and the
        # progress bar; every .item() is a host/device synchronization.
        critic_value = loss_critic.item()
        writer.add_scalar("Critic loss", critic_value, global_step=tb_step)
        tb_step += 1

        # Periodically dump sample outputs (skipping the very first batch).
        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=critic_value,
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step
|
|
|
|
|
|
def main():
    """Train the super-resolution GAN on images under ``data/``.

    Builds the data loader, generator, critic, Adam optimizers, losses and
    AMP gradient scalers. If ``config.LOAD_MODEL`` is set, it first resumes
    from ``config.CHECKPOINT_GEN`` / ``config.CHECKPOINT_DISC``. It then runs
    ``config.NUM_EPOCHS`` epochs of ``train_fn``, logging to TensorBoard under
    ``logs``. If ``config.SAVE_MODEL`` is set, both checkpoints are saved after
    every epoch.
    """
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    # Only the generator is passed through initialize_weights here.
    initialize_weights(gen)
    # Adam with beta1=0, beta2=0.9, as used in the WGAN-GP paper.
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    tb_step = 0
    l1 = nn.L1Loss()
    gen.train()
    disc.train()
    vgg_loss = VGGLoss()

    # Separate scalers so the generator and critic keep independent loss scales.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN,
            gen,
            opt_gen,
            config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC,
            disc,
            opt_disc,
            config.LEARNING_RATE,
        )

    # The writer is used as a context manager so that buffered events are
    # flushed and the event file is closed even if training is interrupted
    # (the original never closed it).
    with SummaryWriter("logs") as writer:
        for _ in range(config.NUM_EPOCHS):
            tb_step = train_fn(
                loader,
                disc,
                gen,
                opt_gen,
                opt_disc,
                l1,
                vgg_loss,
                g_scaler,
                d_scaler,
                writer,
                tb_step,
            )

            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
|
|
|
|
|
|
if __name__ == "__main__":
    # Flip to False to train from scratch instead of running inference.
    try_model = True

    if not try_model:
        # Full training run.
        main()
    else:
        # Inference only: restore the pretrained generator from
        # config.CHECKPOINT_GEN and run it over the images in test_images/.
        # Per the original note, the super-resolved results end up in saved/
        # (written by plot_examples -- verify in utils).
        generator = Generator(in_channels=3).to(config.DEVICE)
        # load_checkpoint also takes an optimizer and learning rate, so a
        # throwaway Adam with the training settings is built just for it.
        generator_opt = optim.Adam(
            generator.parameters(),
            lr=config.LEARNING_RATE,
            betas=(0.0, 0.9),
        )
        load_checkpoint(
            config.CHECKPOINT_GEN, generator, generator_opt, config.LEARNING_RATE
        )
        plot_examples("test_images/", generator)
|