# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# (synced 2026-02-20 13:50:41 +00:00; 158 lines, 5.5 KiB, Python)
import torch
|
|
from dataset import HorseZebraDataset
|
|
import sys
|
|
from utils import save_checkpoint, load_checkpoint
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import config
|
|
from tqdm import tqdm
|
|
from torchvision.utils import save_image
|
|
from discriminator_model import Discriminator
|
|
from generator_model import Generator
|
|
|
|
def _lsgan_discriminator_loss(disc, real, fake, mse):
    """Return the MSE (least-squares) adversarial loss of ``disc`` on one real/fake pair.

    The discriminator is pushed towards 1 on ``real`` and towards 0 on ``fake``.
    ``fake`` is detached so this loss never back-propagates into the generator
    that produced it.

    Args:
        disc: discriminator module.
        real: batch of real images from the discriminator's domain.
        fake: batch of generated images for that domain (graph still attached).
        mse: MSE criterion used as the adversarial loss.

    Returns:
        ``(loss, pred_real, pred_fake)``. The raw discriminator outputs are
        returned so the caller can log them.
    """
    pred_real = disc(real)
    pred_fake = disc(fake.detach())
    real_loss = mse(pred_real, torch.ones_like(pred_real))
    fake_loss = mse(pred_fake, torch.zeros_like(pred_fake))
    return real_loss + fake_loss, pred_real, pred_fake


def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one CycleGAN training epoch over ``loader``.

    Naming: ``gen_H`` maps zebra -> horse and ``gen_Z`` maps horse -> zebra.
    ``disc_H`` and ``disc_Z`` judge horse and zebra images respectively.

    Each step first updates both discriminators through the shared ``opt_disc``.
    It then updates both generators through ``opt_gen``, using adversarial,
    cycle-consistency and (when ``config.LAMBDA_IDENTITY`` is non-zero)
    identity losses. Forward passes run under ``torch.cuda.amp.autocast`` and
    the backward/step goes through the matching gradient scaler.

    Args:
        disc_H, disc_Z: horse / zebra discriminators.
        gen_Z, gen_H: horse->zebra / zebra->horse generators.
        loader: iterable yielding ``(zebra, horse)`` batches.
        opt_disc, opt_gen: optimizers for the discriminators / generators.
        l1: criterion for the cycle and identity losses.
        mse: criterion for the adversarial losses.
        d_scaler, g_scaler: ``GradScaler`` instances for the discriminator /
            generator steps.

    Side effects:
        Every 200 batches, writes ``saved_images/horse_{idx}.png`` and
        ``saved_images/zebra_{idx}.png``, creating the directory if needed.
        The progress bar shows running means of ``disc_H``'s outputs on real
        and fake horses.
    """
    import os  # local import: only needed to create the sample-image directory

    # save_image does not create missing parent directories.
    os.makedirs("saved_images", exist_ok=True)

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(config.DEVICE)
        horse = horse.to(config.DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)
            D_H_loss, D_H_real, D_H_fake = _lsgan_discriminator_loss(disc_H, horse, fake_horse, mse)
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()

            fake_zebra = gen_Z(horse)
            D_Z_loss, _, _ = _lsgan_discriminator_loss(disc_Z, zebra, fake_zebra, mse)

            # Average the two discriminators' losses into a single objective.
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators want the discriminators to output 1
            # on fakes. The fakes are NOT detached here, so gradients reach the
            # generators.
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle-consistency loss: zebra -> horse -> zebra and horse -> zebra -> horse.
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * config.LAMBDA_CYCLE
                + cycle_horse_loss * config.LAMBDA_CYCLE
            )

            # Identity loss: a generator fed an image already in its target domain
            # is penalized for changing it. With LAMBDA_IDENTITY == 0 the term adds
            # nothing, so the two extra generator forward passes are skipped.
            # NOTE(review): assumes the generators have no training-mode forward
            # side effects (e.g. BatchNorm running stats) that relied on these passes.
            if config.LAMBDA_IDENTITY:
                identity_zebra = gen_Z(zebra)
                identity_horse = gen_H(horse)
                identity_zebra_loss = l1(zebra, identity_zebra)
                identity_horse_loss = l1(horse, identity_horse)
                G_loss = (
                    G_loss
                    + identity_horse_loss * config.LAMBDA_IDENTITY
                    + identity_zebra_loss * config.LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # Rescale generator outputs (presumably in [-1, 1]) to [0, 1] for saving.
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
|
|
|
|
|
|
|
|
def main():
    """Build the CycleGAN networks, optionally resume from checkpoints, and train.

    All settings come from the project ``config`` module:
    - device, learning rate, batch size, worker count and epoch count;
    - the training directory (expected to contain ``horses`` and ``zebras``);
    - the four checkpoint paths;
    - the ``LOAD_MODEL`` / ``SAVE_MODEL`` switches.

    When ``SAVE_MODEL`` is set, all four networks are checkpointed after every
    epoch.
    """
    disc_H = Discriminator(in_channels=3).to(config.DEVICE)
    disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)

    # One Adam optimizer per role: both discriminators share opt_disc and both
    # generators share opt_gen.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    l1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint path, model, optimizer). The same table drives loading and
    # saving, so the two cannot drift apart. A shared optimizer is stored with
    # each of the models it updates.
    checkpoints = (
        (config.CHECKPOINT_GEN_H, gen_H, opt_gen),
        (config.CHECKPOINT_GEN_Z, gen_Z, opt_gen),
        (config.CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
    )

    if config.LOAD_MODEL:
        for path, model, optimizer in checkpoints:
            load_checkpoint(path, model, optimizer, config.LEARNING_RATE)

    dataset = HorseZebraDataset(
        root_horse=config.TRAIN_DIR + "/horses",
        root_zebra=config.TRAIN_DIR + "/zebras",
        transform=config.transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )

    # Separate scalers, because generators and discriminators are stepped independently.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler)

        if config.SAVE_MODEL:
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=path)
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()