cyclegan, progan

.gitignore
@@ -1,3 +1,4 @@
.idea/
ML/Pytorch/more_advanced/image_captioning/flickr8k/
ML/algorithms/svm/__pycache__/utils.cpython-38.pyc
__pycache__/
@@ -4,17 +4,17 @@ A clean, simple and readable implementation of CycleGAN in PyTorch. I've tried t

## Results

The model was trained on the Zebra<->Horses dataset.

|1st row: Input / 2nd row: Generated / 3rd row: Target|
|1st column: Input / 2nd column: Generated / 3rd column: Re-converted|
|:---:|
||
||
||
||

### Horses and Zebras Dataset

The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/suyashdamle/cyclegan).

### Download pretrained weights

Pretrained weights for Satellite image to Google Map [will upload soon]().
Pretrained weights [will upload soon]().

Extract the zip file and put the pth.tar files in the directory with all the Python files. Make sure you set LOAD_MODEL=True in the config.py file.

BIN ML/Pytorch/GANs/CycleGAN/results/horse_results.png (new file)
BIN ML/Pytorch/GANs/CycleGAN/results/zebra_results.png (new file)
ML/Pytorch/GANs/ProGAN/README.md (new file)
@@ -0,0 +1,39 @@
# ProGAN

A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation are, I would say, on par with the paper; I'll include some example results below.
## Results

The model was trained on the CelebA-HQ dataset.

||
|:---:|
||
||
### CelebA-HQ dataset

The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).

### Download pretrained weights

Pretrained weights [here]().

Extract the zip file and put the pth.tar files in the directory with all the Python files. Make sure you set LOAD_MODEL=True in the config.py file.

### Training

Edit config.py to match the setup you want to use, then run train.py. A sketch of the values train.py expects from config.py is shown below.
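As a rough guide, this is what a config.py could look like, based only on the names train.py references below (config.DEVICE, config.BATCH_SIZES, config.PROGRESSIVE_EPOCHS, and so on); the specific values are assumptions, not the repository's actual defaults:

```python
# config.py -- minimal sketch; tune the values to your hardware and dataset
import torch

START_TRAIN_AT_IMG_SIZE = 4                # resolution to begin progressive training at
DATASET = "celeb_dataset"                  # ImageFolder root with the training images
CHECKPOINT_GEN = "generator.pth.tar"
CHECKPOINT_CRITIC = "critic.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = False                         # set True to resume from the checkpoints above
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 8, 4, 2]   # one entry per resolution 4 -> 1024
CHANNELS_IMG = 3
Z_DIM = 256
IN_CHANNELS = 256
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 4
```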
## ProGAN paper

### Progressive Growing of GANs for Improved Quality, Stability, and Variation by Tero Karras, Timo Aila, Samuli Laine, Jaakko Lehtinen

#### Abstract

We describe a new training methodology for generative adversarial networks. The key idea is to grow both the generator and discriminator progressively: starting from a low resolution, we add new layers that model increasingly fine details as training progresses. This both speeds the training up and greatly stabilizes it, allowing us to produce images of unprecedented quality, e.g., 1024^2 CelebA images. We also propose a simple way to increase the variation in generated images, and achieve a record inception score of 8.80 in unsupervised CIFAR10. Additionally, we describe several implementation details that are important for discouraging unhealthy competition between the generator and discriminator. Finally, we suggest a new metric for evaluating GAN results, both in terms of image quality and variation. As an additional contribution, we construct a higher-quality version of the CelebA dataset.
```
@misc{karras2018progressive,
      title={Progressive Growing of GANs for Improved Quality, Stability, and Variation},
      author={Tero Karras and Timo Aila and Samuli Laine and Jaakko Lehtinen},
      year={2018},
      eprint={1710.10196},
      archivePrefix={arXiv},
      primaryClass={cs.NE}
}
```
@@ -22,7 +22,7 @@ so specifically the first 5 layers the channels stay the same,
whereas when we increase the img_size (towards the later layers)
we decrease the number of channels by 1/2, 1/4, etc.
"""
factors = [1, 1, 1, 1, 1/2, 1/4, 1/4, 1/8, 1/16]
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
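To make the corrected `factors` list concrete, here is the channel multiplier it implies at each stage of the progression, assuming `in_channels = 256` (the value used in the `__main__` test block further down); purely illustrative:

```python
from math import log2

in_channels = 256
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
for step, f in enumerate(factors):
    res = 4 * 2 ** step
    print(f"{res}x{res}: {int(in_channels * f)} channels")
# 4x4..32x32: 256, 64x64: 128, 128x128: 64, 256x256: 32, 512x512: 16, 1024x1024: 8
```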

class WSConv2d(nn.Module):
@@ -31,7 +31,7 @@ class WSConv2d(nn.Module):
    Note that input is multiplied rather than changing weights
    this will have the same result.

    Inspired by:
    Inspired and looked at:
    https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py
    """
@@ -39,17 +39,17 @@ class WSConv2d(nn.Module):
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding
        )
        self.scale = (gain / (self.conv.weight[0].numel())) ** 0.5
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale)
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
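Since WSConv2d's docstring notes that multiplying the input by `scale` is equivalent to scaling the weights, here is a tiny self-contained check of that equivalence (not part of the repository; bias disabled so the comparison is exact):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(1, 3, 16, 16)
scale = 0.5

out_scaled_input = conv(x * scale)          # what WSConv2d does
with torch.no_grad():
    conv.weight *= scale                    # equivalent: scale the weights instead
out_scaled_weights = conv(x)

print(torch.allclose(out_scaled_input, out_scaled_weights, atol=1e-6))  # True
```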
class PixelNorm(nn.Module):
@@ -58,9 +58,7 @@ class PixelNorm(nn.Module):
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(
            torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon
        )
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
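For reference, the forward pass above is the pixelwise feature vector normalization from the ProGAN paper: each pixel's feature vector is divided by its RMS across channels,

$$b_{x,y} = \frac{a_{x,y}}{\sqrt{\tfrac{1}{N}\sum_{j=0}^{N-1}\left(a_{x,y}^{j}\right)^{2} + \epsilon}}, \qquad \epsilon = 10^{-8},$$

where $N$ is the number of channels (the `dim=1` mean in the code).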
class ConvBlock(nn.Module):
@@ -81,42 +79,48 @@ class ConvBlock(nn.Module):


class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_size, img_channels=3):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])

        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # Create progression blocks and rgb layers
        channels = in_channels
        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        # we need to double img for log2(img_size/4) and
        # +1 in loop for initial 4x4
        for idx in range(int(log2(img_size/4)) + 1):
            conv_in = channels
            conv_out = int(in_channels*factors[idx])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(WSConv2d(conv_out, img_channels, kernel_size=1, stride=1, padding=0))
            channels = conv_out
        for i in range(
            len(factors) - 1
        ):  # -1 to prevent index error because of factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        #assert 0 <= alpha <= 1, "Alpha not between 0 and 1"
        #assert upscaled.shape == generated.shape
        # alpha should be scalar within [0, 1], and upscaled.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        upscaled = self.initial(x)
        out = self.prog_blocks[0](upscaled)
        out = self.initial(x)

        if steps == 0:
            return self.rgb_layers[0](out)
            return self.initial_rgb(out)

        for step in range(1, steps+1):
        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)
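A quick usage sketch of the updated Generator signature (the constructor values here are arbitrary; the remainder of the forward pass, which converts to RGB and fades in the last two resolutions, is cut off by the hunk above but implied by the asserts in the `__main__` block further down):

```python
import torch

gen = Generator(z_dim=256, in_channels=256, img_channels=3)
z = torch.randn(4, 256, 1, 1)

# steps=3 means three doublings from 4x4, i.e. 4 * 2**3 = 32
imgs = gen(z, alpha=0.5, steps=3)
print(imgs.shape)  # expected: torch.Size([4, 3, 32, 32])
```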
@@ -130,76 +134,101 @@ class Generator(nn.Module):


class Discriminator(nn.Module):
    def __init__(self, img_size, z_dim, in_channels, img_channels=3):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Create progression blocks and rgb layers
        channels = in_channels
        for idx in range(int(log2(img_size/4)) + 1):
            conv_in = int(in_channels * factors[idx])
            conv_out = channels
            self.rgb_layers.append(WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0))
        # here we work backwards from factors because the discriminator
        # should be mirrored from the generator. So the first prog_block and
        # rgb layer we append will work for input size 1024x1024, then 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            channels = conv_in
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # +1 to in_channels because we concatenate from minibatch std
        self.conv = WSConv2d(in_channels + 1, z_dim, kernel_size=4, stride=1, padding=0)
        self.linear = nn.Linear(z_dim, 1)
        # perhaps confusing name "initial_rgb": this is just the RGB layer for 4x4 input size,
        # did this to "mirror" the generator initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # down sampling using avg pool

        # this is the block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # we use this instead of linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        #assert 0 <= alpha <= 1, "Alpha needs to be between [0, 1]"
        #assert downscaled.shape == out.shape
        # alpha should be scalar within [0, 1], and downscaled.shape == out.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0)
            .mean()
            .repeat(x.shape[0], 1, x.shape[2], x.shape[3])
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std for each example (across all channels, and pixels) then we repeat it
        # for a single channel and concatenate it with the image. In this way the discriminator
        # will get information about the variation in the batch/image
        return torch.cat([x, batch_statistics], dim=1)
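To illustrate the shapes in minibatch_std above (sizes chosen arbitrarily): the whole batch contributes a single scalar std, which is broadcast to one extra feature map per example:

```python
import torch

x = torch.randn(4, 16, 8, 8)                             # (N, C, H, W)
stats = torch.std(x, dim=0).mean().repeat(4, 1, 8, 8)    # (4, 1, 8, 8), same scalar everywhere
out = torch.cat([x, stats], dim=1)                        # (4, 17, 8, 8)
print(out.shape)  # torch.Size([4, 17, 8, 8])
```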
    def forward(self, x, alpha, steps):
        out = self.rgb_layers[steps](x)  # convert from rgb as initial step
        # where we should start in the list of prog_blocks, maybe a bit confusing but
        # the last is for the 4x4. So example let's say steps=1, then we should start
        # at the second to last because input_size will be 8x8. If steps==0 we just
        # use the final block
        cur_step = len(self.prog_blocks) - steps

        # convert from rgb as initial step, this will depend on
        # the image size (each will have its own rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e., image is 4x4
            out = self.minibatch_std(out)
            out = self.conv(out)
            return self.linear(out.view(-1, out.shape[1]))
            return self.final_block(out).view(out.shape[0], -1)

        # index steps which has the "reverse" fade_in
        downscaled = self.rgb_layers[steps - 1](self.avg_pool(x))
        out = self.avg_pool(self.prog_blocks[steps](out))
        # because prog_blocks might change the channels, for down scale we use rgb_layer
        # from previous/smaller size which in our case correlates to +1 in the indexing
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # the fade_in is done first between the downscaled and the input
        # this is opposite from the generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(steps - 1, 0, -1):
            downscaled = self.avg_pool(out)
            out = self.prog_blocks[step](downscaled)
        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        out = self.conv(out)
        return self.linear(out.view(-1, out.shape[1]))
        return self.final_block(out).view(out.shape[0], -1)
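A worked example of the cur_step indexing used in the discriminator forward above (assuming the 9-entry factors list, so len(prog_blocks) == 8 and rgb_layers has 9 entries, index 0 for 1024x1024 down to index 8 for 4x4):

```python
# steps = 0  (4x4 input)       -> cur_step = 8: only rgb_layers[8] (initial_rgb) + final_block
# steps = 1  (8x8 input)       -> cur_step = 7: rgb_layers[7], prog_blocks[7], then final_block
# steps = 3  (32x32 input)     -> cur_step = 5: prog_blocks[5], [6], [7] downscale 32 -> 4
# steps = 8  (1024x1024 input) -> cur_step = 0: every prog_block runs
```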
if __name__ == "__main__":
    import time
    Z_DIM = 100
    IN_CHANNELS = 16
    img_size = 512
    IN_CHANNELS = 256
    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)

    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        x = torch.randn((5, Z_DIM, 1, 1))
        gen = Generator(Z_DIM, IN_CHANNELS, img_size=img_size)
        disc = Discriminator(img_size, Z_DIM, IN_CHANNELS)
        start = time.time()
        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            z = gen(x, alpha=0.5, steps=num_steps)
        print(prof)
        gen_time = time.time()-start
        t = time.time()
        out = disc(z, 0.01, num_steps)
        disc_time = time.time()-t
        print(gen_time, disc_time)
        #print(disc(z, 0.01, num_steps).shape)
        x = torch.randn((1, Z_DIM, 1, 1))
        z = gen(x, 0.5, steps=num_steps)
        assert z.shape == (1, 3, img_size, img_size)
        out = critic(z, alpha=0.5, steps=num_steps)
        assert out.shape == (1, 1)
        print(f"Success! At img size: {img_size}")

BIN ML/Pytorch/GANs/ProGAN/results/64_examples.png (new file)
BIN ML/Pytorch/GANs/ProGAN/results/result1.png (new file)
@@ -1,5 +1,4 @@
def func(x=1, y=2, **kwargs):
    print(x, y)
it = iter(l)


print(func(x=3, y=4))
for el in it:
    print(el, next(it))
@@ -1,54 +1,50 @@
""" Training of ProGAN using WGAN-GP loss"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
from utils import (
    gradient_penalty,
    plot_to_tensorboard,
    save_checkpoint,
    load_checkpoint,
    generate_examples,
)
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import time
import config

torch.backends.cudnn.benchmarks = True
torch.manual_seed(0)

# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZES = [128, 128, 64, 16, 8, 4, 2, 2, 1]
IMAGE_SIZE = 128
CHANNELS_IMG = 3
Z_DIM = 128
IN_CHANNELS = 128
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1
PROGRESSIVE_EPOCHS = [2 ** i for i in range(int(log2(IMAGE_SIZE / 4) + 1))]
PROGRESSIVE_EPOCHS = [8 for i in range(int(log2(IMAGE_SIZE / 4) + 1))]
fixed_noise = torch.randn(8, Z_DIM, 1, 1).to(device)
NUM_WORKERS = 4
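The (now removed) module-level hyperparameters above imply the following progressive schedule; a small worked example for IMAGE_SIZE = 128, shown only for orientation:

```python
from math import log2

IMAGE_SIZE = 128
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1   # = 6 resolutions: 4, 8, 16, 32, 64, 128
# step:        0    1    2    3   4   5
# resolution:  4    8    16   32  64  128
# batch size:  128  128  64   16  8   4    (BATCH_SIZES[step])
# epochs:      8    8    8    8   8   8    (PROGRESSIVE_EPOCHS[step])
print(NUM_STEPS)  # 6
```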
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(config.CHANNELS_IMG)],
                [0.5 for _ in range(config.CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size/4))]
    dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset


def train_fn(
    critic,
    gen,
@@ -60,85 +56,119 @@ def train_fn(
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    start = time.time()
    total_time = 0
    training = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(training):
        real = real.to(device)
    loop = tqdm(loader, leave=True)
    # critic_losses = []
    reals = 0
    fakes = 0
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
        noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step).reshape(-1)
            critic_fake = critic(fake, alpha, step).reshape(-1)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            reals += critic_real.mean().item()
            fakes += critic_fake.mean().item()
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + config.LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen.zero_grad()
        fake = gen(noise, alpha, step)
        gen_fake = critic(fake, alpha, step).reshape(-1)
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)
        loss_gen.backward()
        opt_gen.step()

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset)  # - step
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)
        total_time += time.time()-model_start

        if batch_idx % 300 == 0:
        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step)
                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
        loop.set_postfix(
            reals=reals / (batch_idx + 1),
            fakes=fakes / (batch_idx + 1),
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha
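For reference, the critic objective assembled in train_fn above is the standard WGAN-GP loss plus the paper's small drift term:

$$L_{critic} = \mathbb{E}[D(\tilde{x})] - \mathbb{E}[D(x)] + \lambda_{GP}\,\mathbb{E}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big] + 0.001\,\mathbb{E}[D(x)^2],$$

where $x$ are reals, $\tilde{x}$ are fakes, $\hat{x}$ are random interpolations between the two (see gradient_penalty in utils.py below), and $\lambda_{GP}$ is config.LAMBDA_GP; the generator simply minimizes $-\mathbb{E}[D(\tilde{x})]$.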
def main():
    # initialize gen and disc, note: discriminator should be called critic,
    # according to WGAN paper (since it no longer outputs between [0, 1])
    gen = Generator(Z_DIM, IN_CHANNELS, img_size=IMAGE_SIZE, img_channels=CHANNELS_IMG).to(device)
    critic = Discriminator(IMAGE_SIZE, Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)
    # but really who cares..
    gen = Generator(
        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    critic = Discriminator(
        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)

    # initializate optimizer
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    # initialize optimizers and scalers for FP16 training
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(
        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
    )
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    # for tensorboard plotting
    writer = SummaryWriter(f"logs/gan")
    writer = SummaryWriter(f"logs/gan1")

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
        )

    load_checkpoint(torch.load("celeba_wgan_gp.pth.tar"), gen, critic)
    gen.train()
    critic.train()

    tensorboard_step = 0
    for step, num_epochs in enumerate(PROGRESSIVE_EPOCHS):
        alpha = 0.01
        if step < 3:
            continue
    # start at step that corresponds to img size that we set in config
    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
        print(f"Current image size: {4 * 2 ** step}")

        if step == 4:
            print(f"Img size is: {4*2**step}")

        loader, dataset = get_loader(4 * 2 ** step)
        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
@@ -152,14 +182,16 @@ def main():
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
            )

            checkpoint = {'gen': gen.state_dict(),
                          'critic': critic.state_dict(),
                          'opt_gen': opt_gen.state_dict(),
                          'opt_critic': opt_critic.state_dict()}
            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

        step += 1  # progress to the next img size

        save_checkpoint(checkpoint)


if __name__ == "__main__":
    main()
@@ -1,54 +0,0 @@
import torch
import torchvision
import torch.nn as nn

# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake * (1 - beta)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, gen, disc, opt_gen=None, opt_disc=None):
    print("=> Loading checkpoint")
    gen.load_state_dict(checkpoint['gen'])
    disc.load_state_dict(checkpoint['critic'])

    if opt_gen != None and opt_disc != None:
        opt_gen.load_state_dict(checkpoint['opt_gen'])
        opt_disc.load_state_dict(checkpoint['opt_critic'])
@@ -1,2 +0,0 @@
Put images in images folder, text files for labels in labels folder.
Then under COCO put train.csv, and test.csv
@@ -1,2 +0,0 @@
Put images in images folder, text files for labels in labels folder.
Then under PASCAL_VOC put train.csv, and test.csv
@@ -11,11 +11,11 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4
BATCH_SIZE = 32
IMAGE_SIZE = 416
NUM_CLASSES = 80
LEARNING_RATE = 3e-5
NUM_CLASSES = 20
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 100
CONF_THRESHOLD = 0.6
CONF_THRESHOLD = 0.05
MAP_IOU_THRESH = 0.5
NMS_IOU_THRESH = 0.45
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
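For orientation, with IMAGE_SIZE = 416 the three prediction scales above come out to S = [416 // 32, 416 // 16, 416 // 8] = [13, 26, 52] grid cells per side.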
@@ -47,9 +47,9 @@ train_transforms = A.Compose(
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=10, p=0.4, border_mode=cv2.BORDER_CONSTANT
                    rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
                ),
                A.IAAAffine(shear=10, p=0.4, mode="constant"),
                A.IAAAffine(shear=15, p=0.5, mode="constant"),
            ],
            p=1.0,
        ),
@@ -44,7 +44,7 @@ class YoloLoss(nn.Module):
        anchors = anchors.reshape(1, 3, 1, 1, 2)
        box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))
        object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
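The box_preds line above mirrors the YOLOv3 box parameterization (here relative to the cell, without the grid offsets), used to build the IoU-scaled objectness target; with t the raw network outputs and (p_w, p_h) the anchor sizes:

$$b_{xy} = \sigma(t_{xy}), \qquad b_{wh} = (p_w, p_h)\, e^{t_{wh}}.$$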
        # ======================== #
        # FOR BOX COORDINATES #
@@ -19,6 +19,8 @@ from utils import (
    plot_couple_examples
)
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.benchmark = True

@@ -80,19 +82,16 @@ def main():
        #plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if config.SAVE_MODEL:
            save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")
        #if config.SAVE_MODEL:
        #    save_checkpoint(model, optimizer, filename=f"checkpoint.pth.tar")

        #print(f"Currently epoch {epoch}")
        #print("On Train Eval loader:")
        #check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
        #print("On Train loader:")
        #check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)

        if epoch % 10 == 0 and epoch > 0:
            print("On Test loader:")
        if epoch > 0 and epoch % 3 == 0:
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)

            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
@@ -108,7 +107,7 @@ def main():
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")

            model.train()


if __name__ == "__main__":
@@ -380,8 +380,6 @@ def check_class_accuracy(model, loader, threshold):
    tot_obj, correct_obj = 0, 0

    for idx, (x, y) in enumerate(tqdm(loader)):
        if idx == 100:
            break
        x = x.to(config.DEVICE)
        with torch.no_grad():
            out = model(x)

@@ -1 +0,0 @@
Download and put pretrained weights here!