StyleGAN, ESRGAN, SRGAN code
BIN ML/Pytorch/GANs/ESRGAN/ESRGAN.png (new binary file, 131 KiB)
BIN ML/Pytorch/GANs/ESRGAN/ESRGAN_generator.pth (new binary file)

ML/Pytorch/GANs/ESRGAN/config.py (new file)
@@ -0,0 +1,48 @@
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth"
CHECKPOINT_DISC = "disc.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10000
BATCH_SIZE = 16
LAMBDA_GP = 10
NUM_WORKERS = 4
HIGH_RES = 128
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3

highres_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

lowres_transform = A.Compose(
    [
        # albumentations expects an OpenCV interpolation flag here; the original
        # passed PIL's Image.BICUBIC (== 3), which OpenCV reads as INTER_AREA
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
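
A minimal sketch of how these transforms are meant to compose (dummy input array, not part of the diff; the real pipeline is dataset.py below): crop and augment once, then derive the LR/HR pair from the same crop.

import numpy as np
img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)  # stand-in for a loaded RGB image
crop = both_transforms(image=img)["image"]         # 128x128x3 augmented crop (still a numpy array)
low_res = lowres_transform(image=crop)["image"]    # 3x32x32 float tensor
high_res = highres_transform(image=crop)["image"]  # 3x128x128 float tensor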

ML/Pytorch/GANs/ESRGAN/dataset.py (new file)
@@ -0,0 +1,49 @@
import os
import cv2
import config
from torch.utils.data import Dataset, DataLoader


class MyImageFolder(Dataset):
    def __init__(self, root_dir):
        super(MyImageFolder, self).__init__()
        self.data = []
        self.root_dir = root_dir
        self.class_names = os.listdir(root_dir)

        for index, name in enumerate(self.class_names):
            files = os.listdir(os.path.join(root_dir, name))
            self.data += list(zip(files, [index] * len(files)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        root_and_dir = os.path.join(self.root_dir, self.class_names[label])

        image = cv2.imread(os.path.join(root_and_dir, img_file))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        both_transform = config.both_transforms(image=image)["image"]
        low_res = config.lowres_transform(image=both_transform)["image"]
        high_res = config.highres_transform(image=both_transform)["image"]
        return low_res, high_res


def test():
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(dataset, batch_size=8)

    for low_res, high_res in loader:
        print(low_res.shape)
        print(high_res.shape)


if __name__ == "__main__":
    test()

ML/Pytorch/GANs/ESRGAN/loss.py (new file)
@@ -0,0 +1,19 @@
import torch.nn as nn
from torchvision.models import vgg19
import config


class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features[:35].eval().to(config.DEVICE)

        for param in self.vgg.parameters():
            param.requires_grad = False

        self.loss = nn.MSELoss()

    def forward(self, input, target):
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)
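
Why `features[:35]`: in torchvision's vgg19, index 34 is the conv5_4 layer, so the slice ends right after that convolution and before its ReLU, matching ESRGAN's choice of pre-activation VGG features. A quick check (a sketch, not part of the diff):

from torchvision.models import vgg19
features = vgg19(pretrained=False).features
print(features[34])  # Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
print(features[35])  # ReLU(inplace=True), excluded by the [:35] slice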

ML/Pytorch/GANs/ESRGAN/model.py (new file)
@@ -0,0 +1,154 @@
import torch
from torch import nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_act, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(
            in_channels,
            out_channels,
            **kwargs,
            bias=True,
        )
        self.act = nn.LeakyReLU(0.2, inplace=True) if use_act else nn.Identity()

    def forward(self, x):
        return self.act(self.cnn(x))


class UpsampleBlock(nn.Module):
    def __init__(self, in_c, scale_factor=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.conv = nn.Conv2d(in_c, in_c, 3, 1, 1, bias=True)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        return self.act(self.conv(self.upsample(x)))


class DenseResidualBlock(nn.Module):
    def __init__(self, in_channels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        self.blocks = nn.ModuleList()

        for i in range(5):
            self.blocks.append(
                ConvBlock(
                    in_channels + channels * i,
                    channels if i <= 3 else in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    use_act=i <= 3,
                )
            )

    def forward(self, x):
        new_inputs = x
        for block in self.blocks:
            out = block(new_inputs)
            new_inputs = torch.cat([new_inputs, out], dim=1)
        return self.residual_beta * out + x


class RRDB(nn.Module):
    def __init__(self, in_channels, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        self.rrdb = nn.Sequential(*[DenseResidualBlock(in_channels) for _ in range(3)])

    def forward(self, x):
        return self.rrdb(x) * self.residual_beta + x


class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=23):
        super().__init__()
        self.initial = nn.Conv2d(
            in_channels,
            num_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.residuals = nn.Sequential(*[RRDB(num_channels) for _ in range(num_blocks)])
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels), UpsampleBlock(num_channels),
        )
        self.final = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_channels, in_channels, 3, 1, 1, bias=True),
        )

    def forward(self, x):
        initial = self.initial(x)
        x = self.conv(self.residuals(initial)) + initial
        x = self.upsamples(x)
        return self.final(x)


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    use_act=True,
                ),
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)


def initialize_weights(model, scale=0.1):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight.data)
            m.weight.data *= scale

        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight.data)
            m.weight.data *= scale


def test():
    gen = Generator()
    disc = Discriminator()
    low_res = 24
    x = torch.randn((5, 3, low_res, low_res))
    gen_out = gen(x)
    disc_out = disc(gen_out)

    print(gen_out.shape)
    print(disc_out.shape)


if __name__ == "__main__":
    test()
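
Channel bookkeeping for DenseResidualBlock (a sketch, not part of the diff): each of the five ConvBlocks sees the concatenation of the block input and all previous outputs, and the last one projects back to in_channels with no activation.

in_channels, channels = 64, 32
for i in range(5):
    out = channels if i <= 3 else in_channels
    print(f"conv {i}: {in_channels + channels * i} -> {out} channels")
# conv 0: 64 -> 32, conv 1: 96 -> 32, conv 2: 128 -> 32, conv 3: 160 -> 32, conv 4: 192 -> 64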

BIN ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png (new binary file, 543 KiB)
BIN ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png (new binary file, 375 KiB)
BIN ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png (new binary file, 124 KiB)
BIN ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png (new binary file, 199 KiB)
BIN ML/Pytorch/GANs/ESRGAN/saved/head_LR.png (new binary file, 137 KiB)
BIN ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png (new binary file, 123 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png (new binary file, 33 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png (new binary file, 30 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png (new binary file, 10 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png (new binary file, 14 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png (new binary file, 8.9 KiB)
BIN ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png (new binary file, 11 KiB)

ML/Pytorch/GANs/ESRGAN/train.py (new file)
@@ -0,0 +1,154 @@
import torch
import config
from torch import nn
from torch import optim
from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator, initialize_weights
from tqdm import tqdm
from dataset import MyImageFolder
from torch.utils.tensorboard import SummaryWriter

torch.backends.cudnn.benchmark = True


def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            critic_fake = disc(fake.detach())
            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
            )

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
        tb_step += 1

        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=loss_critic.item(),
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step


def main():
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    initialize_weights(gen)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    writer = SummaryWriter("logs")
    tb_step = 0
    l1 = nn.L1Loss()
    gen.train()
    disc.train()
    vgg_loss = VGGLoss()

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN,
            gen,
            opt_gen,
            config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC,
            disc,
            opt_disc,
            config.LEARNING_RATE,
        )

    for epoch in range(config.NUM_EPOCHS):
        tb_step = train_fn(
            loader,
            disc,
            gen,
            opt_gen,
            opt_disc,
            l1,
            vgg_loss,
            g_scaler,
            d_scaler,
            writer,
            tb_step,
        )

        if config.SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)


if __name__ == "__main__":
    try_model = True

    if try_model:
        # Will just use pretrained weights and run on images
        # in test_images/, saving the super-resolved results in saved/
        gen = Generator(in_channels=3).to(config.DEVICE)
        opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
        load_checkpoint(
            config.CHECKPOINT_GEN,
            gen,
            opt_gen,
            config.LEARNING_RATE,
        )
        plot_examples("test_images/", gen)
    else:
        # This will train from scratch
        main()
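
The total generator objective in train_fn, restated from the weights in the code above:

$$\mathcal{L}_G = 10^{-2}\,\lVert G(x_{LR}) - y_{HR}\rVert_1 \;+\; \mathcal{L}_{VGG}\big(G(x_{LR}),\, y_{HR}\big) \;-\; 5\cdot 10^{-3}\,\mathbb{E}\big[D(G(x_{LR}))\big]$$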

ML/Pytorch/GANs/ESRGAN/utils.py (new file)
@@ -0,0 +1,67 @@
import torch
import os
import config
import numpy as np
from PIL import Image
from torchvision.utils import save_image


def gradient_penalty(critic, real, fake, device):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake.detach() * (1 - alpha)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just use the learning rate of the old
    # checkpoint, which can lead to many hours of debugging
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def plot_examples(low_res_folder, gen):
    files = os.listdir(low_res_folder)

    gen.eval()
    for file in files:
        # was hardcoded to "test_images/"; use the folder argument instead
        image = Image.open(os.path.join(low_res_folder, file))
        with torch.no_grad():
            upscaled_img = gen(
                config.test_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config.DEVICE)
            )
        save_image(upscaled_img, f"saved/{file}")
    gen.train()
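
gradient_penalty above computes the WGAN-GP regularizer, restated:

$$\mathrm{GP} = \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \qquad \hat{x} = \alpha\, x_{\mathrm{real}} + (1 - \alpha)\, x_{\mathrm{fake}}, \quad \alpha \sim U[0, 1]$$

and train.py adds it to the critic loss with weight LAMBDA_GP = 10.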

BIN ML/Pytorch/GANs/SRGAN/architecture.png (new binary file, 1013 KiB)

ML/Pytorch/GANs/SRGAN/config.py (new file)
@@ -0,0 +1,47 @@
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth.tar"
CHECKPOINT_DISC = "disc.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
BATCH_SIZE = 16
NUM_WORKERS = 4
HIGH_RES = 96
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3

highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

lowres_transform = A.Compose(
    [
        # albumentations expects an OpenCV interpolation flag here; the original
        # passed PIL's Image.BICUBIC (== 3), which OpenCV reads as INTER_AREA
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
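
Note the asymmetric normalization (a reading of the code, not stated in the diff): HR targets are scaled to [-1, 1] to match the generator's tanh output, while LR inputs stay in [0, 1]; utils.plot_examples undoes this with `* 0.5 + 0.5` before saving.

# Effect of the two Normalize calls on a white pixel (255 -> 1.0 after the /255 step):
# highres_transform: (1.0 - 0.5) / 0.5 = 1.0, and black maps to -1.0  (range [-1, 1])
# lowres_transform:  (1.0 - 0.0) / 1.0 = 1.0, and black maps to  0.0  (range [0, 1])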

ML/Pytorch/GANs/SRGAN/dataset.py (new file)
@@ -0,0 +1,43 @@
import os
import numpy as np
import config
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class MyImageFolder(Dataset):
    def __init__(self, root_dir):
        super(MyImageFolder, self).__init__()
        self.data = []
        self.root_dir = root_dir
        self.class_names = os.listdir(root_dir)

        for index, name in enumerate(self.class_names):
            files = os.listdir(os.path.join(root_dir, name))
            self.data += list(zip(files, [index] * len(files)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        root_and_dir = os.path.join(self.root_dir, self.class_names[label])

        image = np.array(Image.open(os.path.join(root_and_dir, img_file)))
        image = config.both_transforms(image=image)["image"]
        high_res = config.highres_transform(image=image)["image"]
        low_res = config.lowres_transform(image=image)["image"]
        return low_res, high_res


def test():
    dataset = MyImageFolder(root_dir="new_data/")
    loader = DataLoader(dataset, batch_size=1, num_workers=8)

    for low_res, high_res in loader:
        print(low_res.shape)
        print(high_res.shape)


if __name__ == "__main__":
    test()

ML/Pytorch/GANs/SRGAN/loss.py (new file)
@@ -0,0 +1,21 @@
import torch.nn as nn
from torchvision.models import vgg19
import config

# phi_5,4: features after the 4th conv (post-activation) in the 5th block,
# right before the 5th max-pooling layer

class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE)
        self.loss = nn.MSELoss()

        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

ML/Pytorch/GANs/SRGAN/model.py (new file)
@@ -0,0 +1,128 @@
import torch
from torch import nn


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, x):
        return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))


class UpsampleBlock(nn.Module):
    def __init__(self, in_c, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
        self.ps = nn.PixelShuffle(scale_factor)  # in_c * 4, H, W --> in_c, H*2, W*2
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        return self.act(self.ps(self.conv(x)))


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out + x


class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        x = self.convblock(x) + initial
        x = self.upsamples(x)
        return torch.tanh(self.final(x))


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)


def test():
    low_resolution = 24  # 96x96 -> 24x24
    with torch.cuda.amp.autocast():
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)

        print(gen_out.shape)
        print(disc_out.shape)


if __name__ == "__main__":
    test()
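
The UpsampleBlock relies on sub-pixel convolution; a minimal shape check (a sketch, not part of the diff):

import torch
import torch.nn as nn

ps = nn.PixelShuffle(2)             # rearranges (N, C*r^2, H, W) -> (N, C, H*r, W*r)
x = torch.randn(1, 64 * 4, 24, 24)  # what self.conv produces for in_c=64, scale_factor=2
print(ps(x).shape)                  # torch.Size([1, 64, 48, 48])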

ML/Pytorch/GANs/SRGAN/train.py (new file)
@@ -0,0 +1,88 @@
import torch
import config
from torch import nn
from torch import optim
from utils import load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator
from tqdm import tqdm
from dataset import MyImageFolder

torch.backends.cudnn.benchmark = True


def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        fake = gen(low_res)
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())
        disc_loss_real = bce(
            disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
        )
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        disc_fake = disc(fake)
        # l2_loss = mse(fake, high_res)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if idx % 200 == 0:
            plot_examples("test_images/", gen)


def main():
    dataset = MyImageFolder(root_dir="new_data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)  # was img_channels=3, which is not a Discriminator parameter
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN,
            gen,
            opt_gen,
            config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    for epoch in range(config.NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)

        if config.SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)


if __name__ == "__main__":
    main()
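
The real-label targets above use one-sided label smoothing (a reading of the code): instead of exactly 1, each real label is drawn uniformly from [0.9, 1.0], while fake labels stay at 0.

import torch
disc_real = torch.randn(4, 1)  # stand-in discriminator outputs
real_targets = torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)  # uniform in [0.9, 1.0]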

ML/Pytorch/GANs/SRGAN/utils.py (new file)
@@ -0,0 +1,66 @@
import torch
import os
import config
import numpy as np
from PIL import Image
from torchvision.utils import save_image


def gradient_penalty(critic, real, fake, device):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake.detach() * (1 - alpha)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just use the learning rate of the old
    # checkpoint, which can lead to many hours of debugging
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def plot_examples(low_res_folder, gen):
    files = os.listdir(low_res_folder)

    gen.eval()
    for file in files:
        # was hardcoded to "test_images/"; use the folder argument instead
        image = Image.open(os.path.join(low_res_folder, file))
        with torch.no_grad():
            upscaled_img = gen(
                config.test_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config.DEVICE)
            )
        # the generator outputs in [-1, 1]; rescale to [0, 1] before saving
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    gen.train()

ML/Pytorch/GANs/StyleGAN/config.py (new file)
@@ -0,0 +1,25 @@
import torch

#from utils import seed_everything

START_TRAIN_AT_IMG_SIZE = 32
DATASET = 'FFHQ_32'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LOAD_MODEL = False
SAVE_MODEL = True
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 32, 32, 16, 8, 4, 2]
CHANNELS_IMG = 3
Z_DIM = 512
W_DIM = 512
IN_CHANNELS = 512
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [50] * 100
FIXED_NOISE = torch.randn((8, Z_DIM)).to(DEVICE)
NUM_WORKERS = 6
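
BATCH_SIZES and PROGRESSIVE_EPOCHS are indexed by the progressive-growing step, derived from the image size as in train.py; the mapping (a sketch, not part of the diff):

from math import log2
for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    print(img_size, "-> step", int(log2(img_size / 4)))  # 4 -> 0, 8 -> 1, ..., 1024 -> 8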

ML/Pytorch/GANs/StyleGAN/make_resized_data.py (new file)
@@ -0,0 +1,9 @@
import os
from PIL import Image
from tqdm import tqdm

root_dir = "FFHQ/images1024x1024"
out_dir = "FFHQ_resized"
os.makedirs(out_dir, exist_ok=True)  # the original assumed this folder already existed

for file in tqdm(os.listdir(root_dir)):
    img = Image.open(os.path.join(root_dir, file)).resize((128, 128))
    img.save(os.path.join(out_dir, file))

ML/Pytorch/GANs/StyleGAN/model.py (new file)
@@ -0,0 +1,293 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]


class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)


class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)


class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise


class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias


class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2,
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)


class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features, gain=2,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (gain / in_features) ** 0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # initialize linear layer
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias
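
WSConv2d and WSLinear implement ProGAN-style equalized learning rate: weights are drawn from N(0, 1) and multiplied by a He-style constant at every forward pass instead of baking the scale into the stored parameters. The scale for the default conv (a sketch, not part of the diff):

in_channels, kernel_size, gain = 512, 3, 2
scale = (gain / (in_channels * kernel_size ** 2)) ** 0.5
print(scale)  # ~0.0208, so the effective weight std matches He initialization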
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x


class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 to prevent an index error from factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha should be a scalar within [0, 1], and upscaled.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            # note: the original returned self.initial_rgb(x), which skipped the
            # second noise injection and AdaIN computed into `out` above
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # The number of channels in upscaled stays the same, while out,
        # which has moved through prog_blocks, might change. To ensure we can
        # convert both to RGB we use different rgb_layers:
        # (steps - 1) for upscaled and steps for out
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
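
During a resolution transition the generator blends the new block's RGB output with the upscaled RGB image from the previous resolution; a minimal demo of fade_in (a sketch, not part of the diff):

import torch
alpha = 0.3                           # ramps from ~0 to 1 during training
upscaled = torch.randn(2, 3, 8, 8)    # rgb_layers[steps - 1](...) output
generated = torch.randn(2, 3, 8, 8)   # rgb_layers[steps](...) output
faded = torch.tanh(alpha * generated + (1 - alpha) * upscaled)  # as in Generator.fade_in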
class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # here we work backwards through factors because the discriminator
        # should be mirrored from the generator. So the first prog_block and
        # rgb layer we append will work for input size 1024x1024, then 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # perhaps a confusing name, but "initial_rgb" is just the RGB layer for
        # the 4x4 input size, named to mirror the generator's initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # downsampling using avg pool

        # this is the block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from minibatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # we use this instead of a linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Fade in between the avg-pooled (downscaled) input and the CNN output"""
        # alpha should be a scalar within [0, 1], and downscaled.shape == out.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std across examples (then average over channels and pixels),
        # repeat it as a single constant channel, and concatenate it with the input.
        # This way the discriminator gets information about the variation in the batch
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # where we should start in the list of prog_blocks; maybe a bit confusing, but
        # the last one is for 4x4. So for example if steps=1 we should start at the
        # second to last, because input_size will be 8x8. If steps==0 we just
        # use the final block
        cur_step = len(self.prog_blocks) - steps

        # convert from rgb as the initial step; this depends on
        # the image size (each size has its own rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e., image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # because prog_blocks might change the channels, for the downscaled path
        # we use the rgb_layer from the previous/smaller size, which in our case
        # corresponds to +1 in the indexing
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # the fade_in is done first, between the downscaled input and the
        # processed input; this is the opposite order from the generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
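
minibatch_std appends one constant channel carrying the batch-wide standard deviation; a shape check (a sketch, not part of the diff):

import torch
x = torch.randn(4, 512, 4, 4)
stat = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
print(torch.cat([x, stat], dim=1).shape)  # torch.Size([4, 513, 4, 4])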
if __name__ == "__main__":
    Z_DIM = 512
    W_DIM = 512
    IN_CHANNELS = 512
    gen = Generator(Z_DIM, W_DIM, IN_CHANNELS, img_channels=3).to("cuda")
    disc = Discriminator(IN_CHANNELS, img_channels=3).to("cuda")

    tot = 0
    for param in gen.parameters():
        tot += param.numel()

    print(tot)
    import sys
    sys.exit()  # remove this early exit to run the shape checks below

    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        x = torch.randn((2, Z_DIM)).to("cuda")
        z = gen(x, 0.5, steps=num_steps)
        assert z.shape == (2, 3, img_size, img_size)
        out = disc(z, alpha=0.5, steps=num_steps)
        assert out.shape == (2, 1)
        print(f"Success! At img size: {img_size}")

ML/Pytorch/GANs/StyleGAN/prepare_data.py (new file)
@@ -0,0 +1,30 @@
from PIL import Image
import os
from multiprocessing import Pool

root_dir = "FFHQ/images1024x1024/"
files = os.listdir(root_dir)


def resize(file, size, folder_to_save):
    image = Image.open(root_dir + file).resize((size, size), Image.LANCZOS)
    image.save(folder_to_save + file, quality=100)


if __name__ == "__main__":
    for img_size in [4, 8, 512, 1024]:
        folder_name = "FFHQ_" + str(img_size) + "/images/"
        if not os.path.isdir(folder_name):
            os.makedirs(folder_name)

        data = [(file, img_size, folder_name) for file in files]
        pool = Pool()
        pool.starmap(resize, data)

ML/Pytorch/GANs/StyleGAN/readme_important.txt (new file)
@@ -0,0 +1,2 @@
This implementation doesn't work yet; I need to debug it and find where I've made a mistake. It seems to do something
sensible, but it's nowhere near the level of performance reported in the original paper.

ML/Pytorch/GANs/StyleGAN/train.py (new file)
@@ -0,0 +1,200 @@
""" Training of StyleGAN (with ProGAN-style progressive growing) using WGAN-GP loss """

import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import (
    gradient_penalty,
    plot_to_tensorboard,
    save_checkpoint,
    load_checkpoint,
    EMA,
)
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import config

torch.backends.cudnn.benchmark = True  # was `benchmarks`, which is not a real flag


def get_loader(image_size):
    transform = transforms.Compose(
        [
            # transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(config.CHANNELS_IMG)],
                [0.5 for _ in range(config.CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset


def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
    ema,
):
    loop = tqdm(loader, leave=True)
    gen2 = Generator(
        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # which is equivalent to minimizing the negative of that expression
        noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure it stays below 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 100 == 0:
            ema(gen)
            with torch.no_grad():
                ema.copy_weights_to(gen2)
                fixed_fakes = gen2(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha
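
The alpha update above ramps the fade-in to 1 over half the epochs allotted to a resolution; a quick sanity check (a sketch, assuming 50 epochs per step and an FFHQ-sized dataset):

epochs_for_step, dataset_len = 50, 70_000       # PROGRESSIVE_EPOCHS[step] = 50
increment_per_image = 1 / (epochs_for_step * 0.5 * dataset_len)
images_until_full = 1 / increment_per_image
print(images_until_full / dataset_len)          # 25.0 epochs until alpha reaches 1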
def main():
    # initialize gen and disc; note: the discriminator should be called critic
    # according to the WGAN paper (since it no longer outputs between [0, 1]),
    # but really who cares..
    gen = Generator(
        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    critic = Discriminator(
        config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    ema = EMA(gamma=0.999, save_frequency=2000)
    # initialize optimizers and scalers for FP16 training
    # (the mapping network gets a 100x lower learning rate, as in the StyleGAN paper)
    opt_gen = optim.Adam(
        [
            {"params": [param for name, param in gen.named_parameters() if "map" not in name]},
            {"params": gen.map.parameters(), "lr": 1e-5},
        ],
        lr=config.LEARNING_RATE,
        betas=(0.0, 0.99),
    )
    opt_critic = optim.Adam(
        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
    )
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    # for tensorboard plotting
    writer = SummaryWriter("logs/gan")

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
        )

    gen.train()
    critic.train()

    tensorboard_step = 0
    # start at the step that corresponds to the img size set in config
    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64->4
        print(f"Current image size: {4 * 2 ** step}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
                ema,
            )

            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

        step += 1  # progress to the next img size


if __name__ == "__main__":
    main()

ML/Pytorch/GANs/StyleGAN/utils.py (new file)
@@ -0,0 +1,126 @@
import torch
import random
import numpy as np
import os
import torchvision
import warnings


# Log losses occasionally and plot example images to tensorboard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # was accepted but never logged in the original
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location="cuda")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just use the learning rate of the old
    # checkpoint, which can lead to many hours of debugging
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class EMA:
    # Found this useful (thanks alexis-jacq):
    # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3
    def __init__(self, gamma=0.99, save=True, save_frequency=100, save_filename="ema_weights.pth"):
        """
        Initialize the decay factor for the
        exponential moving average and the dictionary
        where we store the model parameters
        """
        self.gamma = gamma
        self.registered = {}
        self.save_filename = save_filename
        self.save_frequency = save_frequency
        self.count = 0

        if os.path.exists(save_filename):
            self.registered = torch.load(self.save_filename)

        if not save:
            warnings.warn("Note that the exponential moving average weights will not be saved to a .pth file!")

    def register_weights(self, model):
        """
        Registers the weights of the model which will
        later be used when we take the moving average
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.registered[name] = param.clone().detach()

    def __call__(self, model):
        self.count += 1
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name not in self.registered:
                    new_weight = param.clone().detach()
                else:
                    # standard EMA update: keep gamma of the running average and blend in
                    # (1 - gamma) of the current weights; the original had the two
                    # coefficients swapped, which made the average track the raw weights
                    new_weight = (
                        self.gamma * self.registered[name]
                        + (1 - self.gamma) * param.clone().detach()
                    )
                self.registered[name] = new_weight

        if self.count % self.save_frequency == 0:
            self.save_ema_weights()

    def copy_weights_to(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.registered[name]

    def save_ema_weights(self):
        torch.save(self.registered, self.save_filename)
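
A minimal usage sketch for EMA (hypothetical toy model, not part of the diff):

import torch.nn as nn
model = nn.Linear(4, 4)
ema = EMA(gamma=0.999)
ema.register_weights(model)   # snapshot the starting weights
for _ in range(10):           # after each optimizer step:
    ema(model)                # fold current weights into the running average
shadow = nn.Linear(4, 4)
ema.copy_weights_to(shadow)   # evaluate with the smoothed weights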