Aladdin Persson
2021-03-06 21:09:08 +01:00
parent 00ea9fea1f
commit 2a397b17e2
14 changed files with 425 additions and 0 deletions

View File: README.md

@@ -0,0 +1,37 @@
# CycleGAN
A clean, simple, and readable implementation of CycleGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. I would say the results from this implementation are on par with the paper; I'll include some example results below.
## Results
The model was trained on the Horse<->Zebra dataset.
|1st row: Input / 2nd row: Generated / 3rd row: Target|
|:---:|
|![](results/results_anime.png)|
|![](results/results_maps.png)|
### Horses and Zebras Dataset
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/suyashdamle/cyclegan).
### Download pretrained weights
Pretrained weights for satellite image to Google Map [will upload soon]().
Extract the zip file and put the `.pth.tar` files in the directory with all the Python files. Make sure you set `LOAD_MODEL = True` in `config.py`.
### Training
Edit `config.py` to match the setup you want to use, then run `train.py`.
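For example (assuming PyTorch, torchvision, albumentations, numpy, Pillow, and tqdm are installed):
```
python train.py
```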
## CycleGAN paper
### Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros
#### Abstract
Image-to-image translation is a class of vision and graphics problems where the goal is to learn the mapping between an input image and an output image using a training set of aligned image pairs. However, for many tasks, paired training data will not be available. We present an approach for learning to translate an image from a source domain X to a target domain Y in the absence of paired examples. Our goal is to learn a mapping G:X→Y such that the distribution of images from G(X) is indistinguishable from the distribution Y using an adversarial loss. Because this mapping is highly under-constrained, we couple it with an inverse mapping F:Y→X and introduce a cycle consistency loss to push F(G(X))≈X (and vice versa). Qualitative results are presented on several tasks where paired training data does not exist, including collection style transfer, object transfiguration, season transfer, photo enhancement, etc. Quantitative comparisons against several prior methods demonstrate the superiority of our approach.
```
@misc{zhu2020unpaired,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Jun-Yan Zhu and Taesung Park and Phillip Isola and Alexei A. Efros},
year={2020},
eprint={1703.10593},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
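As a minimal, hypothetical sketch of the cycle-consistency idea from the abstract (identity mappings stand in for the real `Generator` networks defined in `generator_model.py`):
```python
import torch
import torch.nn as nn

# Identity stand-ins for the two mappings; in this repo both are Generator instances.
G = nn.Identity()  # G : X -> Y (e.g. horse -> zebra)
F = nn.Identity()  # F : Y -> X (e.g. zebra -> horse)

l1 = nn.L1Loss()
x = torch.randn(1, 3, 256, 256)  # a (normalized) image from domain X
cycle_loss = l1(F(G(x)), x)      # the cycle term pushes F(G(x)) ≈ x
print(cycle_loss.item())         # 0.0 here, since the stand-ins are identities
```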

View File: config.py

@@ -0,0 +1,29 @@
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 10
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
transforms = A.Compose(
[
A.Resize(width=256, height=256),
A.HorizontalFlip(p=0.5),
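        # Normalize with mean=std=0.5 maps pixels from [0, 255] to [-1, 1], matching the generator's tanh output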
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(),
],
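    # apply the exact same (random) augmentations to a second image passed as image0=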
additional_targets={"image0": "image"},
)

View File: dataset.py

@@ -0,0 +1,41 @@
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
class HorseZebraDataset(Dataset):
def __init__(self, root_zebra, root_horse, transform=None):
self.root_zebra = root_zebra
self.root_horse = root_horse
self.transform = transform
self.zebra_images = os.listdir(root_zebra)
self.horse_images = os.listdir(root_horse)
        self.length_dataset = max(len(self.zebra_images), len(self.horse_images))  # e.g. 1000 zebra vs. 1500 horse images
self.zebra_len = len(self.zebra_images)
self.horse_len = len(self.horse_images)
def __len__(self):
return self.length_dataset
def __getitem__(self, index):
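        # the two domains are unpaired and differently sized, so wrap each index into its own range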
zebra_img = self.zebra_images[index % self.zebra_len]
horse_img = self.horse_images[index % self.horse_len]
zebra_path = os.path.join(self.root_zebra, zebra_img)
horse_path = os.path.join(self.root_horse, horse_img)
zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
horse_img = np.array(Image.open(horse_path).convert("RGB"))
if self.transform:
augmentations = self.transform(image=zebra_img, image0=horse_img)
zebra_img = augmentations["image"]
horse_img = augmentations["image0"]
return zebra_img, horse_img

View File: discriminator_model.py

@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x):
return self.conv(x)
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(
in_channels,
features[0],
kernel_size=4,
stride=2,
padding=1,
padding_mode="reflect",
),
nn.LeakyReLU(0.2, inplace=True),
)
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
in_channels = feature
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
self.model = nn.Sequential(*layers)
def forward(self, x):
x = self.initial(x)
return torch.sigmoid(self.model(x))
def test():
x = torch.randn((5, 3, 256, 256))
model = Discriminator(in_channels=3)
preds = model(x)
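    # PatchGAN output: a grid of per-patch real/fake probabilities, here torch.Size([5, 1, 30, 30])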
print(preds.shape)
if __name__ == "__main__":
test()

View File: generator_model.py

@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels),
nn.ReLU(inplace=True) if use_act else nn.Identity()
)
def forward(self, x):
return self.conv(x)
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block = nn.Sequential(
ConvBlock(channels, channels, kernel_size=3, padding=1),
ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
)
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=9):
super().__init__()
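        # c7s1-64, d128, d256, 9 x R256, u128, u64, c7s1-3: the 256x256 architecture from the CycleGAN paper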
self.initial = nn.Sequential(
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True),
)
self.down_blocks = nn.ModuleList(
[
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
]
)
self.res_blocks = nn.Sequential(
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
)
self.up_blocks = nn.ModuleList(
[
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
]
)
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
def forward(self, x):
x = self.initial(x)
for layer in self.down_blocks:
x = layer(x)
x = self.res_blocks(x)
for layer in self.up_blocks:
x = layer(x)
return torch.tanh(self.last(x))
def test():
img_channels = 3
img_size = 256
x = torch.randn((2, img_channels, img_size, img_size))
    gen = Generator(img_channels, num_residuals=9)  # keyword arg: the second positional parameter is num_features, not num_residuals
print(gen(x).shape)
if __name__ == "__main__":
test()

View File: train.py

@@ -0,0 +1,158 @@
import torch
from dataset import HorseZebraDataset
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
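    # running averages of disc_H's outputs on real vs. fake horses, reported in the tqdm bar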
H_reals = 0
H_fakes = 0
loop = tqdm(loader, leave=True)
for idx, (zebra, horse) in enumerate(loop):
zebra = zebra.to(config.DEVICE)
horse = horse.to(config.DEVICE)
# Train Discriminators H and Z
with torch.cuda.amp.autocast():
fake_horse = gen_H(zebra)
D_H_real = disc_H(horse)
D_H_fake = disc_H(fake_horse.detach())
H_reals += D_H_real.mean().item()
H_fakes += D_H_fake.mean().item()
D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
D_H_loss = D_H_real_loss + D_H_fake_loss
fake_zebra = gen_Z(horse)
D_Z_real = disc_Z(zebra)
D_Z_fake = disc_Z(fake_zebra.detach())
D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
            # put it together
D_loss = (D_H_loss + D_Z_loss)/2
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
d_scaler.step(opt_disc)
d_scaler.update()
# Train Generators H and Z
with torch.cuda.amp.autocast():
# adversarial loss for both generators
D_H_fake = disc_H(fake_horse)
D_Z_fake = disc_Z(fake_zebra)
loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))
# cycle loss
cycle_zebra = gen_Z(fake_horse)
cycle_horse = gen_H(fake_zebra)
cycle_zebra_loss = l1(zebra, cycle_zebra)
cycle_horse_loss = l1(horse, cycle_horse)
# identity loss (remove these for efficiency if you set lambda_identity=0)
identity_zebra = gen_Z(zebra)
identity_horse = gen_H(horse)
identity_zebra_loss = l1(zebra, identity_zebra)
identity_horse_loss = l1(horse, identity_horse)
            # add all together
G_loss = (
loss_G_Z
+ loss_G_H
+ cycle_zebra_loss * config.LAMBDA_CYCLE
+ cycle_horse_loss * config.LAMBDA_CYCLE
+ identity_horse_loss * config.LAMBDA_IDENTITY
+ identity_zebra_loss * config.LAMBDA_IDENTITY
)
opt_gen.zero_grad()
g_scaler.scale(G_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
if idx % 200 == 0:
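            # assumes a saved_images/ directory exists; create it up front, e.g. os.makedirs("saved_images", exist_ok=True)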
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
def main():
disc_H = Discriminator(in_channels=3).to(config.DEVICE)
disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
opt_disc = optim.Adam(
list(disc_H.parameters()) + list(disc_Z.parameters()),
lr=config.LEARNING_RATE,
betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
list(gen_Z.parameters()) + list(gen_H.parameters()),
lr=config.LEARNING_RATE,
betas=(0.5, 0.999),
)
L1 = nn.L1Loss()
mse = nn.MSELoss()
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
)
dataset = HorseZebraDataset(
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
)
val_dataset = HorseZebraDataset(
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
)
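    # note: val_loader is set up here but not consumed in the training loop below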
loader = DataLoader(
dataset,
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
if config.SAVE_MODEL:
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
save_checkpoint(gen_Z, opt_gen, filename=config.CHECKPOINT_GEN_Z)
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
if __name__ == "__main__":
main()

View File: utils.py

@@ -0,0 +1,35 @@
import os
import random

import numpy as np
import torch

import config
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
    # if we don't reset it here, the optimizer keeps the learning rate stored in
    # the old checkpoint, which can lead to many hours of debugging
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def seed_everything(seed=42):
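    # make runs reproducible across python, numpy and torch (cudnn.deterministic trades speed for determinism)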
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False