mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
cyclegan
This commit is contained in:
37
ML/Pytorch/GANs/CycleGAN/README.md
Normal file
37
ML/Pytorch/GANs/CycleGAN/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# CycleGAN
|
||||
A clean, simple and readable implementation of CycleGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is on par with the paper, I'll include some examples results below.
|
||||
|
||||
## Results
|
||||
The model was trained on Zebra<->Horses dataset.
|
||||
|
||||
|1st row: Input / 2nd row: Generated / 3rd row: Target|
|
||||
|:---:|
|
||||
||
|
||||
||
|
||||
|
||||
|
||||
### Horses and Zebras Dataset
|
||||
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/suyashdamle/cyclegan).
|
||||
|
||||
### Download pretrained weights
|
||||
Pretrained weights for Satellite image to Google Map [will upload soon]().
|
||||
|
||||
Extract the zip file and put the pth.tar files in the directory with all the python files. Make sure you put LOAD_MODEL=True in the config.py file.
|
||||
|
||||
### Training
|
||||
Edit the config.py file to match the setup you want to use. Then run train.py
|
||||
|
||||
## CycleGAN paper
|
||||
### Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros
|
||||
|
||||
#### Abstract
|
||||
Image-to-image translation is a class of vision and graphics problems where the goal is to learn the mapping between an input image and an output image using a training set of aligned image pairs. However, for many tasks, paired training data will not be available. We present an approach for learning to translate an image from a source domain X to a target domain Y in the absence of paired examples. Our goal is to learn a mapping G:X→Y such that the distribution of images from G(X) is indistinguishable from the distribution Y using an adversarial loss. Because this mapping is highly under-constrained, we couple it with an inverse mapping F:Y→X and introduce a cycle consistency loss to push F(G(X))≈X (and vice versa). Qualitative results are presented on several tasks where paired training data does not exist, including collection style transfer, object transfiguration, season transfer, photo enhancement, etc. Quantitative comparisons against several prior methods demonstrate the superiority of our approach. ```
|
||||
@misc{zhu2020unpaired,
|
||||
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
|
||||
author={Jun-Yan Zhu and Taesung Park and Phillip Isola and Alexei A. Efros},
|
||||
year={2020},
|
||||
eprint={1703.10593},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
29
ML/Pytorch/GANs/CycleGAN/config.py
Normal file
29
ML/Pytorch/GANs/CycleGAN/config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
TRAIN_DIR = "data/train"
|
||||
VAL_DIR = "data/val"
|
||||
BATCH_SIZE = 1
|
||||
LEARNING_RATE = 1e-5
|
||||
LAMBDA_IDENTITY = 0.0
|
||||
LAMBDA_CYCLE = 10
|
||||
NUM_WORKERS = 4
|
||||
NUM_EPOCHS = 10
|
||||
LOAD_MODEL = True
|
||||
SAVE_MODEL = True
|
||||
CHECKPOINT_GEN_H = "genh.pth.tar"
|
||||
CHECKPOINT_GEN_Z = "genz.pth.tar"
|
||||
CHECKPOINT_CRITIC_H = "critich.pth.tar"
|
||||
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
|
||||
|
||||
transforms = A.Compose(
|
||||
[
|
||||
A.Resize(width=256, height=256),
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
additional_targets={"image0": "image"},
|
||||
)
|
||||
41
ML/Pytorch/GANs/CycleGAN/dataset.py
Normal file
41
ML/Pytorch/GANs/CycleGAN/dataset.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from PIL import Image
|
||||
import os
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
class HorseZebraDataset(Dataset):
|
||||
def __init__(self, root_zebra, root_horse, transform=None):
|
||||
self.root_zebra = root_zebra
|
||||
self.root_horse = root_horse
|
||||
self.transform = transform
|
||||
|
||||
self.zebra_images = os.listdir(root_zebra)
|
||||
self.horse_images = os.listdir(root_horse)
|
||||
self.length_dataset = max(len(self.zebra_images), len(self.horse_images)) # 1000, 1500
|
||||
self.zebra_len = len(self.zebra_images)
|
||||
self.horse_len = len(self.horse_images)
|
||||
|
||||
def __len__(self):
|
||||
return self.length_dataset
|
||||
|
||||
def __getitem__(self, index):
|
||||
zebra_img = self.zebra_images[index % self.zebra_len]
|
||||
horse_img = self.horse_images[index % self.horse_len]
|
||||
|
||||
zebra_path = os.path.join(self.root_zebra, zebra_img)
|
||||
horse_path = os.path.join(self.root_horse, horse_img)
|
||||
|
||||
zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
|
||||
horse_img = np.array(Image.open(horse_path).convert("RGB"))
|
||||
|
||||
if self.transform:
|
||||
augmentations = self.transform(image=zebra_img, image0=horse_img)
|
||||
zebra_img = augmentations["image"]
|
||||
horse_img = augmentations["image0"]
|
||||
|
||||
return zebra_img, horse_img
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
53
ML/Pytorch/GANs/CycleGAN/discriminator_model.py
Normal file
53
ML/Pytorch/GANs/CycleGAN/discriminator_model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
|
||||
super().__init__()
|
||||
self.initial = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
features[0],
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
padding_mode="reflect",
|
||||
),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
|
||||
layers = []
|
||||
in_channels = features[0]
|
||||
for feature in features[1:]:
|
||||
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
|
||||
in_channels = feature
|
||||
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial(x)
|
||||
return torch.sigmoid(self.model(x))
|
||||
|
||||
def test():
|
||||
x = torch.randn((5, 3, 256, 256))
|
||||
model = Discriminator(in_channels=3)
|
||||
preds = model(x)
|
||||
print(preds.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
72
ML/Pytorch/GANs/CycleGAN/generator_model.py
Normal file
72
ML/Pytorch/GANs/CycleGAN/generator_model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
|
||||
if down
|
||||
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True) if use_act else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
ConvBlock(channels, channels, kernel_size=3, padding=1),
|
||||
ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.block(x)
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, img_channels, num_features = 64, num_residuals=9):
|
||||
super().__init__()
|
||||
self.initial = nn.Sequential(
|
||||
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
|
||||
nn.InstanceNorm2d(num_features),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.down_blocks = nn.ModuleList(
|
||||
[
|
||||
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
|
||||
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
|
||||
]
|
||||
)
|
||||
self.res_blocks = nn.Sequential(
|
||||
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
|
||||
)
|
||||
self.up_blocks = nn.ModuleList(
|
||||
[
|
||||
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
]
|
||||
)
|
||||
|
||||
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial(x)
|
||||
for layer in self.down_blocks:
|
||||
x = layer(x)
|
||||
x = self.res_blocks(x)
|
||||
for layer in self.up_blocks:
|
||||
x = layer(x)
|
||||
return torch.tanh(self.last(x))
|
||||
|
||||
def test():
|
||||
img_channels = 3
|
||||
img_size = 256
|
||||
x = torch.randn((2, img_channels, img_size, img_size))
|
||||
gen = Generator(img_channels, 9)
|
||||
print(gen(x).shape)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
158
ML/Pytorch/GANs/CycleGAN/train.py
Normal file
158
ML/Pytorch/GANs/CycleGAN/train.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import torch
|
||||
from dataset import HorseZebraDataset
|
||||
import sys
|
||||
from utils import save_checkpoint, load_checkpoint
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import config
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import save_image
|
||||
from discriminator_model import Discriminator
|
||||
from generator_model import Generator
|
||||
|
||||
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
|
||||
H_reals = 0
|
||||
H_fakes = 0
|
||||
loop = tqdm(loader, leave=True)
|
||||
|
||||
for idx, (zebra, horse) in enumerate(loop):
|
||||
zebra = zebra.to(config.DEVICE)
|
||||
horse = horse.to(config.DEVICE)
|
||||
|
||||
# Train Discriminators H and Z
|
||||
with torch.cuda.amp.autocast():
|
||||
fake_horse = gen_H(zebra)
|
||||
D_H_real = disc_H(horse)
|
||||
D_H_fake = disc_H(fake_horse.detach())
|
||||
H_reals += D_H_real.mean().item()
|
||||
H_fakes += D_H_fake.mean().item()
|
||||
D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
|
||||
D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
|
||||
D_H_loss = D_H_real_loss + D_H_fake_loss
|
||||
|
||||
fake_zebra = gen_Z(horse)
|
||||
D_Z_real = disc_Z(zebra)
|
||||
D_Z_fake = disc_Z(fake_zebra.detach())
|
||||
D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
|
||||
D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
|
||||
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
|
||||
|
||||
# put it togethor
|
||||
D_loss = (D_H_loss + D_Z_loss)/2
|
||||
|
||||
opt_disc.zero_grad()
|
||||
d_scaler.scale(D_loss).backward()
|
||||
d_scaler.step(opt_disc)
|
||||
d_scaler.update()
|
||||
|
||||
# Train Generators H and Z
|
||||
with torch.cuda.amp.autocast():
|
||||
# adversarial loss for both generators
|
||||
D_H_fake = disc_H(fake_horse)
|
||||
D_Z_fake = disc_Z(fake_zebra)
|
||||
loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
|
||||
loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))
|
||||
|
||||
# cycle loss
|
||||
cycle_zebra = gen_Z(fake_horse)
|
||||
cycle_horse = gen_H(fake_zebra)
|
||||
cycle_zebra_loss = l1(zebra, cycle_zebra)
|
||||
cycle_horse_loss = l1(horse, cycle_horse)
|
||||
|
||||
# identity loss (remove these for efficiency if you set lambda_identity=0)
|
||||
identity_zebra = gen_Z(zebra)
|
||||
identity_horse = gen_H(horse)
|
||||
identity_zebra_loss = l1(zebra, identity_zebra)
|
||||
identity_horse_loss = l1(horse, identity_horse)
|
||||
|
||||
# add all togethor
|
||||
G_loss = (
|
||||
loss_G_Z
|
||||
+ loss_G_H
|
||||
+ cycle_zebra_loss * config.LAMBDA_CYCLE
|
||||
+ cycle_horse_loss * config.LAMBDA_CYCLE
|
||||
+ identity_horse_loss * config.LAMBDA_IDENTITY
|
||||
+ identity_zebra_loss * config.LAMBDA_IDENTITY
|
||||
)
|
||||
|
||||
opt_gen.zero_grad()
|
||||
g_scaler.scale(G_loss).backward()
|
||||
g_scaler.step(opt_gen)
|
||||
g_scaler.update()
|
||||
|
||||
if idx % 200 == 0:
|
||||
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
|
||||
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
|
||||
|
||||
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
disc_H = Discriminator(in_channels=3).to(config.DEVICE)
|
||||
disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
|
||||
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
|
||||
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
|
||||
opt_disc = optim.Adam(
|
||||
list(disc_H.parameters()) + list(disc_Z.parameters()),
|
||||
lr=config.LEARNING_RATE,
|
||||
betas=(0.5, 0.999),
|
||||
)
|
||||
|
||||
opt_gen = optim.Adam(
|
||||
list(gen_Z.parameters()) + list(gen_H.parameters()),
|
||||
lr=config.LEARNING_RATE,
|
||||
betas=(0.5, 0.999),
|
||||
)
|
||||
|
||||
L1 = nn.L1Loss()
|
||||
mse = nn.MSELoss()
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
|
||||
)
|
||||
|
||||
dataset = HorseZebraDataset(
|
||||
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
|
||||
)
|
||||
val_dataset = HorseZebraDataset(
|
||||
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=config.NUM_WORKERS,
|
||||
pin_memory=True
|
||||
)
|
||||
g_scaler = torch.cuda.amp.GradScaler()
|
||||
d_scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
for epoch in range(config.NUM_EPOCHS):
|
||||
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
|
||||
|
||||
if config.SAVE_MODEL:
|
||||
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
|
||||
save_checkpoint(gen_Z, opt_gen, filename=config.CHECKPOINT_GEN_Z)
|
||||
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
|
||||
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
35
ML/Pytorch/GANs/CycleGAN/utils.py
Normal file
35
ML/Pytorch/GANs/CycleGAN/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import random, torch, os, numpy as np
|
||||
import torch.nn as nn
|
||||
import config
|
||||
import copy
|
||||
|
||||
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
checkpoint = {
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_file, model, optimizer, lr):
|
||||
print("=> Loading checkpoint")
|
||||
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
|
||||
# If we don't do this then it will just have learning rate of old checkpoint
|
||||
# and it will lead to many hours of debugging \:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
def seed_everything(seed=42):
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user