diff --git a/ML/Pytorch/GANs/ESRGAN/ESRGAN.png b/ML/Pytorch/GANs/ESRGAN/ESRGAN.png
new file mode 100644
index 0000000..5226a30
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/ESRGAN.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/ESRGAN_generator.pth b/ML/Pytorch/GANs/ESRGAN/ESRGAN_generator.pth
new file mode 100644
index 0000000..88fa2cc
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/ESRGAN_generator.pth differ
diff --git a/ML/Pytorch/GANs/ESRGAN/config.py b/ML/Pytorch/GANs/ESRGAN/config.py
new file mode 100644
index 0000000..87a2f5c
--- /dev/null
+++ b/ML/Pytorch/GANs/ESRGAN/config.py
@@ -0,0 +1,50 @@
+import torch
+import cv2
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+LOAD_MODEL = True
+SAVE_MODEL = True
+CHECKPOINT_GEN = "gen.pth"
+CHECKPOINT_DISC = "disc.pth"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+LEARNING_RATE = 1e-4
+NUM_EPOCHS = 10000
+BATCH_SIZE = 16
+LAMBDA_GP = 10
+NUM_WORKERS = 4
+HIGH_RES = 128
+LOW_RES = HIGH_RES // 4
+IMG_CHANNELS = 3
+
+highres_transform = A.Compose(
+    [
+        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+        ToTensorV2(),
+    ]
+)
+
+lowres_transform = A.Compose(
+    [
+        # albumentations expects an OpenCV interpolation flag; PIL's Image.BICUBIC (= 3)
+        # silently selects cv2.INTER_AREA instead of bicubic
+        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
+        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+        ToTensorV2(),
+    ]
+)
+
+both_transforms = A.Compose(
+    [
+        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
+        A.HorizontalFlip(p=0.5),
+        A.RandomRotate90(p=0.5),
+    ]
+)
+
+test_transform = A.Compose(
+    [
+        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+        ToTensorV2(),
+    ]
+)
diff --git a/ML/Pytorch/GANs/ESRGAN/dataset.py b/ML/Pytorch/GANs/ESRGAN/dataset.py
new file mode 100644
index 0000000..fc758d5
--- /dev/null
+++ b/ML/Pytorch/GANs/ESRGAN/dataset.py
@@ -0,0 +1,43 @@
+import os
+from torch.utils.data import Dataset, DataLoader
+import config
+import cv2
+
+
+class MyImageFolder(Dataset):
+    def __init__(self, root_dir):
+        super(MyImageFolder, self).__init__()
+        self.data = []
+        self.root_dir = root_dir
+        self.class_names = os.listdir(root_dir)
+
+        for index, name in enumerate(self.class_names):
+            files = os.listdir(os.path.join(root_dir, name))
+            self.data += list(zip(files, [index] * len(files)))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        img_file, label = self.data[index]
+        root_and_dir = os.path.join(self.root_dir, self.class_names[label])
+
+        image = cv2.imread(os.path.join(root_and_dir, img_file))
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        both_transform = config.both_transforms(image=image)["image"]
+        low_res = config.lowres_transform(image=both_transform)["image"]
+        high_res = config.highres_transform(image=both_transform)["image"]
+        return low_res, high_res
+
+
+def test():
+    dataset = MyImageFolder(root_dir="data/")
+    loader = DataLoader(dataset, batch_size=8)
+
+    for low_res, high_res in loader:
+        print(low_res.shape)
+        print(high_res.shape)
+
+
+if __name__ == "__main__":
+    test()
diff --git a/ML/Pytorch/GANs/ESRGAN/loss.py b/ML/Pytorch/GANs/ESRGAN/loss.py
new file mode 100644
index 0000000..8bf5135
--- /dev/null
+++ b/ML/Pytorch/GANs/ESRGAN/loss.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+from torchvision.models import vgg19
+import config
+
+
+class VGGLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # ESRGAN takes VGG19 features before the activation (conv5_4)
+        self.vgg = vgg19(pretrained=True).features[:35].eval().to(config.DEVICE)
+
+        for param in self.vgg.parameters():
+            param.requires_grad = 
False + + self.loss = nn.MSELoss() + + def forward(self, input, target): + vgg_input_features = self.vgg(input) + vgg_target_features = self.vgg(target) + return self.loss(vgg_input_features, vgg_target_features) diff --git a/ML/Pytorch/GANs/ESRGAN/model.py b/ML/Pytorch/GANs/ESRGAN/model.py new file mode 100644 index 0000000..6702874 --- /dev/null +++ b/ML/Pytorch/GANs/ESRGAN/model.py @@ -0,0 +1,154 @@ +import torch +from torch import nn + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, use_act, **kwargs): + super().__init__() + self.cnn = nn.Conv2d( + in_channels, + out_channels, + **kwargs, + bias=True, + ) + self.act = nn.LeakyReLU(0.2, inplace=True) if use_act else nn.Identity() + + def forward(self, x): + return self.act(self.cnn(x)) + + +class UpsampleBlock(nn.Module): + def __init__(self, in_c, scale_factor=2): + super().__init__() + self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest") + self.conv = nn.Conv2d(in_c, in_c, 3, 1, 1, bias=True) + self.act = nn.LeakyReLU(0.2, inplace=True) + + def forward(self, x): + return self.act(self.conv(self.upsample(x))) + + +class DenseResidualBlock(nn.Module): + def __init__(self, in_channels, channels=32, residual_beta=0.2): + super().__init__() + self.residual_beta = residual_beta + self.blocks = nn.ModuleList() + + for i in range(5): + self.blocks.append( + ConvBlock( + in_channels + channels * i, + channels if i <= 3 else in_channels, + kernel_size=3, + stride=1, + padding=1, + use_act=True if i <= 3 else False, + ) + ) + + def forward(self, x): + new_inputs = x + for block in self.blocks: + out = block(new_inputs) + new_inputs = torch.cat([new_inputs, out], dim=1) + return self.residual_beta * out + x + + +class RRDB(nn.Module): + def __init__(self, in_channels, residual_beta=0.2): + super().__init__() + self.residual_beta = residual_beta + self.rrdb = nn.Sequential(*[DenseResidualBlock(in_channels) for _ in range(3)]) + + def forward(self, x): + return self.rrdb(x) * self.residual_beta + x + + +class Generator(nn.Module): + def __init__(self, in_channels=3, num_channels=64, num_blocks=23): + super().__init__() + self.initial = nn.Conv2d( + in_channels, + num_channels, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.residuals = nn.Sequential(*[RRDB(num_channels) for _ in range(num_blocks)]) + self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1) + self.upsamples = nn.Sequential( + UpsampleBlock(num_channels), UpsampleBlock(num_channels), + ) + self.final = nn.Sequential( + nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(num_channels, in_channels, 3, 1, 1, bias=True), + ) + + def forward(self, x): + initial = self.initial(x) + x = self.conv(self.residuals(initial)) + initial + x = self.upsamples(x) + return self.final(x) + + +class Discriminator(nn.Module): + def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]): + super().__init__() + blocks = [] + for idx, feature in enumerate(features): + blocks.append( + ConvBlock( + in_channels, + feature, + kernel_size=3, + stride=1 + idx % 2, + padding=1, + use_act=True, + ), + ) + in_channels = feature + + self.blocks = nn.Sequential(*blocks) + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((6, 6)), + nn.Flatten(), + nn.Linear(512 * 6 * 6, 1024), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(1024, 1), + ) + + def forward(self, x): + x = self.blocks(x) + return self.classifier(x) + +def 
initialize_weights(model, scale=0.1): + for m in model.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data) + m.weight.data *= scale + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight.data) + m.weight.data *= scale + + +def test(): + gen = Generator() + disc = Discriminator() + low_res = 24 + x = torch.randn((5, 3, low_res, low_res)) + gen_out = gen(x) + disc_out = disc(gen_out) + + print(gen_out.shape) + print(disc_out.shape) + +if __name__ == "__main__": + test() + + + + + diff --git a/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png new file mode 100644 index 0000000..477946c Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png new file mode 100644 index 0000000..7f1edf2 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png new file mode 100644 index 0000000..7780a5f Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png new file mode 100644 index 0000000..48a3119 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png new file mode 100644 index 0000000..20bfe2c Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png new file mode 100644 index 0000000..2613e31 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png new file mode 100644 index 0000000..527ceb2 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png new file mode 100644 index 0000000..17284d4 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png new file mode 100644 index 0000000..8e68243 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png new file mode 100644 index 0000000..42cad00 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png new file mode 100644 index 0000000..8324a9c Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png new file mode 100644 index 0000000..f2978f4 Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png differ diff --git a/ML/Pytorch/GANs/ESRGAN/train.py b/ML/Pytorch/GANs/ESRGAN/train.py new file mode 100644 index 0000000..8bf6109 --- /dev/null +++ b/ML/Pytorch/GANs/ESRGAN/train.py @@ -0,0 +1,154 @@ +import torch +import config +from torch import nn 
+from torch import optim
+from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
+from loss import VGGLoss
+from torch.utils.data import DataLoader
+from model import Generator, Discriminator, initialize_weights
+from tqdm import tqdm
+from dataset import MyImageFolder
+from torch.utils.tensorboard import SummaryWriter
+
+torch.backends.cudnn.benchmark = True
+
+def train_fn(
+    loader,
+    disc,
+    gen,
+    opt_gen,
+    opt_disc,
+    l1,
+    vgg_loss,
+    g_scaler,
+    d_scaler,
+    writer,
+    tb_step,
+):
+    loop = tqdm(loader, leave=True)
+
+    for idx, (low_res, high_res) in enumerate(loop):
+        high_res = high_res.to(config.DEVICE)
+        low_res = low_res.to(config.DEVICE)
+
+        with torch.cuda.amp.autocast():
+            fake = gen(low_res)
+            critic_real = disc(high_res)
+            critic_fake = disc(fake.detach())
+            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
+            loss_critic = (
+                -(torch.mean(critic_real) - torch.mean(critic_fake))
+                + config.LAMBDA_GP * gp
+            )
+
+        opt_disc.zero_grad()
+        d_scaler.scale(loss_critic).backward()
+        d_scaler.step(opt_disc)
+        d_scaler.update()
+
+        # Train Generator: maximize E[critic(G(z))], combined with weighted L1 and VGG losses
+        with torch.cuda.amp.autocast():
+            l1_loss = 1e-2 * l1(fake, high_res)
+            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
+            loss_for_vgg = vgg_loss(fake, high_res)
+            gen_loss = l1_loss + loss_for_vgg + adversarial_loss
+
+        opt_gen.zero_grad()
+        g_scaler.scale(gen_loss).backward()
+        g_scaler.step(opt_gen)
+        g_scaler.update()
+
+        writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
+        tb_step += 1
+
+        if idx % 100 == 0 and idx > 0:
+            plot_examples("test_images/", gen)
+
+        loop.set_postfix(
+            gp=gp.item(),
+            critic=loss_critic.item(),
+            l1=l1_loss.item(),
+            vgg=loss_for_vgg.item(),
+            adversarial=adversarial_loss.item(),
+        )
+
+    return tb_step
+
+
+def main():
+    dataset = MyImageFolder(root_dir="data/")
+    loader = DataLoader(
+        dataset,
+        batch_size=config.BATCH_SIZE,
+        shuffle=True,
+        pin_memory=True,
+        num_workers=config.NUM_WORKERS,
+    )
+    gen = Generator(in_channels=3).to(config.DEVICE)
+    disc = Discriminator(in_channels=3).to(config.DEVICE)
+    initialize_weights(gen)
+    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
+    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
+    writer = SummaryWriter("logs")
+    tb_step = 0
+    l1 = nn.L1Loss()
+    gen.train()
+    disc.train()
+    vgg_loss = VGGLoss()
+
+    g_scaler = torch.cuda.amp.GradScaler()
+    d_scaler = torch.cuda.amp.GradScaler()
+
+    if config.LOAD_MODEL:
+        load_checkpoint(
+            config.CHECKPOINT_GEN,
+            gen,
+            opt_gen,
+            config.LEARNING_RATE,
+        )
+        load_checkpoint(
+            config.CHECKPOINT_DISC,
+            disc,
+            opt_disc,
+            config.LEARNING_RATE,
+        )
+
+
+    for epoch in range(config.NUM_EPOCHS):
+        tb_step = train_fn(
+            loader,
+            disc,
+            gen,
+            opt_gen,
+            opt_disc,
+            l1,
+            vgg_loss,
+            g_scaler,
+            d_scaler,
+            writer,
+            tb_step,
+        )
+
+        if config.SAVE_MODEL:
+            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
+            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
+
+
+if __name__ == "__main__":
+    try_model = True
+
+    if try_model:
+        # Use the pretrained generator weights to super-resolve the images
+        # in test_images/ and save the results in saved/
+        gen = Generator(in_channels=3).to(config.DEVICE)
+        opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
+        load_checkpoint(
+            config.CHECKPOINT_GEN,
+            gen,
+            opt_gen,
+            config.LEARNING_RATE,
+        )
+        plot_examples("test_images/", gen)
+    else:
+        # This will train from scratch
+        main()
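Note: the adversarial terms above follow WGAN-GP rather than the relativistic average (RaGAN) loss used in the original ESRGAN paper. For reference, a hedged sketch of the paper's formulation in case it is ever swapped back in; `critic_real`/`critic_fake` are raw critic logits and the helper names here are illustrative, not part of this diff:

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def ragan_disc_loss(critic_real, critic_fake):
    # real images should score above the average fake score, and vice versa
    loss_real = bce(critic_real - critic_fake.mean(), torch.ones_like(critic_real))
    loss_fake = bce(critic_fake - critic_real.mean(), torch.zeros_like(critic_fake))
    return (loss_real + loss_fake) / 2

def ragan_gen_loss(critic_real, critic_fake):
    # mirrored objective: fakes should overtake the average real score
    loss_real = bce(critic_real - critic_fake.mean(), torch.zeros_like(critic_real))
    loss_fake = bce(critic_fake - critic_real.mean(), torch.ones_like(critic_fake))
    return (loss_real + loss_fake) / 2
```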
diff --git a/ML/Pytorch/GANs/ESRGAN/utils.py b/ML/Pytorch/GANs/ESRGAN/utils.py
new file mode 100644
index 0000000..1e93787
--- /dev/null
+++ b/ML/Pytorch/GANs/ESRGAN/utils.py
@@ -0,0 +1,66 @@
+import torch
+import os
+import config
+import numpy as np
+from PIL import Image
+from torchvision.utils import save_image
+
+
+def gradient_penalty(critic, real, fake, device):
+    BATCH_SIZE, C, H, W = real.shape
+    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
+    interpolated_images = real * alpha + fake.detach() * (1 - alpha)
+    interpolated_images.requires_grad_(True)
+
+    # Calculate critic scores
+    mixed_scores = critic(interpolated_images)
+
+    # Take the gradient of the scores with respect to the images
+    gradient = torch.autograd.grad(
+        inputs=interpolated_images,
+        outputs=mixed_scores,
+        grad_outputs=torch.ones_like(mixed_scores),
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+    gradient = gradient.view(gradient.shape[0], -1)
+    gradient_norm = gradient.norm(2, dim=1)
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+    return gradient_penalty
+
+
+def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+    print("=> Saving checkpoint")
+    checkpoint = {
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    torch.save(checkpoint, filename)
+
+
+def load_checkpoint(checkpoint_file, model, optimizer, lr):
+    print("=> Loading checkpoint")
+    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
+    model.load_state_dict(checkpoint["state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    # If we don't do this, the optimizer keeps the learning rate of the old
+    # checkpoint, and it will lead to many hours of debugging \:
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def plot_examples(low_res_folder, gen):
+    files = os.listdir(low_res_folder)
+
+    gen.eval()
+    for file in files:
+        image = Image.open(os.path.join(low_res_folder, file))
+        with torch.no_grad():
+            upscaled_img = gen(
+                config.test_transform(image=np.asarray(image))["image"]
+                .unsqueeze(0)
+                .to(config.DEVICE)
+            )
+        save_image(upscaled_img, f"saved/{file}")
+    gen.train()
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png
new file mode 100644
index 0000000..477946c
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/baboon_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png
new file mode 100644
index 0000000..7f1edf2
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/baby_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png
new file mode 100644
index 0000000..7780a5f
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/butterfly_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png
new file mode 100644
index 0000000..48a3119
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/comic_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png
new file mode 100644
index 0000000..20bfe2c
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/head_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png b/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png
new file mode 100644
index 0000000..2613e31
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/saved/woman_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png
new file mode 100644
index 0000000..527ceb2
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/baboon_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png
new file mode 100644
index 0000000..17284d4
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/baby_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png
new file mode 100644
index 0000000..8e68243
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/butterfly_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png
new file mode 100644
index 0000000..42cad00
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/comic_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png
new file mode 100644
index 0000000..8324a9c
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/head_LR.png differ
diff --git a/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png b/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png
new file mode 100644
index 0000000..f2978f4
Binary files /dev/null and b/ML/Pytorch/GANs/ESRGAN/test_images/woman_LR.png differ
diff --git a/ML/Pytorch/GANs/SRGAN/architecture.png b/ML/Pytorch/GANs/SRGAN/architecture.png
new file mode 100644
index 0000000..098febc
Binary files /dev/null and b/ML/Pytorch/GANs/SRGAN/architecture.png differ
diff --git a/ML/Pytorch/GANs/SRGAN/config.py b/ML/Pytorch/GANs/SRGAN/config.py
new file mode 100644
index 0000000..b307451
--- /dev/null
+++ b/ML/Pytorch/GANs/SRGAN/config.py
@@ -0,0 +1,48 @@
+import torch
+import cv2
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+LOAD_MODEL = True
+SAVE_MODEL = True
+CHECKPOINT_GEN = "gen.pth.tar"
+CHECKPOINT_DISC = "disc.pth.tar"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+LEARNING_RATE = 1e-4
+NUM_EPOCHS = 100
+BATCH_SIZE = 16
+NUM_WORKERS = 4
+HIGH_RES = 96
+LOW_RES = HIGH_RES // 4
+IMG_CHANNELS = 3
+
+highres_transform = A.Compose(
+    [
+        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ToTensorV2(),
+    ]
+)
+
+lowres_transform = A.Compose(
+    [
+        # use the OpenCV interpolation flag (PIL's Image.BICUBIC maps to cv2.INTER_AREA)
+        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
+        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+        ToTensorV2(),
+    ]
+)
+
+both_transforms = A.Compose(
+    [
+        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
+        A.HorizontalFlip(p=0.5),
+        A.RandomRotate90(p=0.5),
+    ]
+)
+
+test_transform = A.Compose(
+    [
+        A.Normalize(mean=[0, 0, 0], 
std=[1, 1, 1]), + ToTensorV2(), + ] +) diff --git a/ML/Pytorch/GANs/SRGAN/dataset.py b/ML/Pytorch/GANs/SRGAN/dataset.py new file mode 100644 index 0000000..403a67d --- /dev/null +++ b/ML/Pytorch/GANs/SRGAN/dataset.py @@ -0,0 +1,43 @@ +import os +import numpy as np +import config +from torch.utils.data import Dataset, DataLoader +from PIL import Image + + +class MyImageFolder(Dataset): + def __init__(self, root_dir): + super(MyImageFolder, self).__init__() + self.data = [] + self.root_dir = root_dir + self.class_names = os.listdir(root_dir) + + for index, name in enumerate(self.class_names): + files = os.listdir(os.path.join(root_dir, name)) + self.data += list(zip(files, [index] * len(files))) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + img_file, label = self.data[index] + root_and_dir = os.path.join(self.root_dir, self.class_names[label]) + + image = np.array(Image.open(os.path.join(root_and_dir, img_file))) + image = config.both_transforms(image=image)["image"] + high_res = config.highres_transform(image=image)["image"] + low_res = config.lowres_transform(image=image)["image"] + return low_res, high_res + + +def test(): + dataset = MyImageFolder(root_dir="new_data/") + loader = DataLoader(dataset, batch_size=1, num_workers=8) + + for low_res, high_res in loader: + print(low_res.shape) + print(high_res.shape) + + +if __name__ == "__main__": + test() diff --git a/ML/Pytorch/GANs/SRGAN/loss.py b/ML/Pytorch/GANs/SRGAN/loss.py new file mode 100644 index 0000000..6fe9861 --- /dev/null +++ b/ML/Pytorch/GANs/SRGAN/loss.py @@ -0,0 +1,21 @@ +import torch.nn as nn +from torchvision.models import vgg19 +import config + +# phi_5,4 5th conv layer before maxpooling but after activation + +class VGGLoss(nn.Module): + def __init__(self): + super().__init__() + self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE) + self.loss = nn.MSELoss() + + for param in self.vgg.parameters(): + param.requires_grad = False + + def forward(self, input, target): + vgg_input_features = self.vgg(input) + vgg_target_features = self.vgg(target) + return self.loss(vgg_input_features, vgg_target_features) + + diff --git a/ML/Pytorch/GANs/SRGAN/model.py b/ML/Pytorch/GANs/SRGAN/model.py new file mode 100644 index 0000000..52eb81c --- /dev/null +++ b/ML/Pytorch/GANs/SRGAN/model.py @@ -0,0 +1,128 @@ +import torch +from torch import nn + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + discriminator=False, + use_act=True, + use_bn=True, + **kwargs, + ): + super().__init__() + self.use_act = use_act + self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn) + self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity() + self.act = ( + nn.LeakyReLU(0.2, inplace=True) + if discriminator + else nn.PReLU(num_parameters=out_channels) + ) + + def forward(self, x): + return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x)) + + +class UpsampleBlock(nn.Module): + def __init__(self, in_c, scale_factor): + super().__init__() + self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1) + self.ps = nn.PixelShuffle(scale_factor) # in_c * 4, H, W --> in_c, H*2, W*2 + self.act = nn.PReLU(num_parameters=in_c) + + def forward(self, x): + return self.act(self.ps(self.conv(x))) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.block1 = ConvBlock( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.block2 = ConvBlock( + 
in_channels,
+            in_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            use_act=False,
+        )
+
+    def forward(self, x):
+        out = self.block1(x)
+        out = self.block2(out)
+        return out + x
+
+
+class Generator(nn.Module):
+    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
+        super().__init__()
+        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
+        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
+        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
+        self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
+        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)
+
+    def forward(self, x):
+        initial = self.initial(x)
+        x = self.residuals(initial)
+        x = self.convblock(x) + initial
+        x = self.upsamples(x)
+        return torch.tanh(self.final(x))
+
+
+class Discriminator(nn.Module):
+    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
+        super().__init__()
+        blocks = []
+        for idx, feature in enumerate(features):
+            blocks.append(
+                ConvBlock(
+                    in_channels,
+                    feature,
+                    kernel_size=3,
+                    stride=1 + idx % 2,
+                    padding=1,
+                    discriminator=True,
+                    use_act=True,
+                    use_bn=False if idx == 0 else True,
+                )
+            )
+            in_channels = feature
+
+        self.blocks = nn.Sequential(*blocks)
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d((6, 6)),
+            nn.Flatten(),
+            nn.Linear(512 * 6 * 6, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(1024, 1),
+        )
+
+    def forward(self, x):
+        x = self.blocks(x)
+        return self.classifier(x)
+
+
+def test():
+    low_resolution = 24  # 96x96 -> 24x24
+    with torch.cuda.amp.autocast():
+        x = torch.randn((5, 3, low_resolution, low_resolution))
+        gen = Generator()
+        gen_out = gen(x)
+        disc = Discriminator()
+        disc_out = disc(gen_out)
+
+        print(gen_out.shape)
+        print(disc_out.shape)
+
+
+if __name__ == "__main__":
+    test()
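A quick shape check for the sub-pixel upsampling in `UpsampleBlock` above (the sizes are illustrative): the 3x3 conv multiplies the channel count by `scale_factor ** 2`, and `nn.PixelShuffle` trades those extra channels back for spatial resolution.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 24, 24)              # (N, C, H, W)
conv = nn.Conv2d(64, 64 * 2 ** 2, 3, 1, 1)  # C -> C * scale_factor^2
ps = nn.PixelShuffle(2)                     # (N, C*r^2, H, W) -> (N, C, H*r, W*r)
print(ps(conv(x)).shape)                    # torch.Size([1, 64, 48, 48])
```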
diff --git a/ML/Pytorch/GANs/SRGAN/train.py b/ML/Pytorch/GANs/SRGAN/train.py
new file mode 100644
index 0000000..5cab4e8
--- /dev/null
+++ b/ML/Pytorch/GANs/SRGAN/train.py
@@ -0,0 +1,89 @@
+import torch
+import config
+from torch import nn
+from torch import optim
+from utils import load_checkpoint, save_checkpoint, plot_examples
+from loss import VGGLoss
+from torch.utils.data import DataLoader
+from model import Generator, Discriminator
+from tqdm import tqdm
+from dataset import MyImageFolder
+
+torch.backends.cudnn.benchmark = True
+
+
+def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
+    loop = tqdm(loader, leave=True)
+
+    for idx, (low_res, high_res) in enumerate(loop):
+        high_res = high_res.to(config.DEVICE)
+        low_res = low_res.to(config.DEVICE)
+
+        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
+        fake = gen(low_res)
+        disc_real = disc(high_res)
+        disc_fake = disc(fake.detach())
+        # one-sided label smoothing on the real labels
+        disc_loss_real = bce(
+            disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
+        )
+        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
+        loss_disc = disc_loss_fake + disc_loss_real
+
+        opt_disc.zero_grad()
+        loss_disc.backward()
+        opt_disc.step()
+
+        # Train Generator: non-saturating loss, max log(D(G(z)))
+        disc_fake = disc(fake)
+        # l2_loss = mse(fake, high_res)  # optional pixel-wise MSE term (unused)
+        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
+        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
+        gen_loss = loss_for_vgg + adversarial_loss
+
+        opt_gen.zero_grad()
+        gen_loss.backward()
+        opt_gen.step()
+
+        if idx % 200 == 0:
+            plot_examples("test_images/", gen)
+
+
+def main():
+    dataset = MyImageFolder(root_dir="new_data/")
+    loader = DataLoader(
+        dataset,
+        batch_size=config.BATCH_SIZE,
+        shuffle=True,
+        pin_memory=True,
+        num_workers=config.NUM_WORKERS,
+    )
+    gen = Generator(in_channels=3).to(config.DEVICE)
+    disc = Discriminator(in_channels=3).to(config.DEVICE)
+    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
+    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
+    mse = nn.MSELoss()
+    bce = nn.BCEWithLogitsLoss()
+    vgg_loss = VGGLoss()
+
+    if config.LOAD_MODEL:
+        load_checkpoint(
+            config.CHECKPOINT_GEN,
+            gen,
+            opt_gen,
+            config.LEARNING_RATE,
+        )
+        load_checkpoint(
+            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
+        )
+
+    for epoch in range(config.NUM_EPOCHS):
+        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
+
+        if config.SAVE_MODEL:
+            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
+            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/ML/Pytorch/GANs/SRGAN/utils.py b/ML/Pytorch/GANs/SRGAN/utils.py
new file mode 100644
index 0000000..a9f3899
--- /dev/null
+++ b/ML/Pytorch/GANs/SRGAN/utils.py
@@ -0,0 +1,66 @@
+import torch
+import os
+import config
+import numpy as np
+from PIL import Image
+from torchvision.utils import save_image
+
+
+def gradient_penalty(critic, real, fake, device):
+    BATCH_SIZE, C, H, W = real.shape
+    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
+    interpolated_images = real * alpha + fake.detach() * (1 - alpha)
+    interpolated_images.requires_grad_(True)
+
+    # Calculate critic scores
+    mixed_scores = critic(interpolated_images)
+
+    # Take the gradient of the scores with respect to the images
+    gradient = torch.autograd.grad(
+        inputs=interpolated_images,
+        outputs=mixed_scores,
+        grad_outputs=torch.ones_like(mixed_scores),
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+    gradient = gradient.view(gradient.shape[0], -1)
+    gradient_norm = gradient.norm(2, dim=1)
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+    return gradient_penalty
+
+
+def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+    print("=> Saving checkpoint")
+    checkpoint = {
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    torch.save(checkpoint, filename)
+
+
+def load_checkpoint(checkpoint_file, model, optimizer, lr):
+    print("=> Loading checkpoint")
+    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
+    model.load_state_dict(checkpoint["state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    # If we don't do this, the optimizer keeps the learning rate of the old
+    # checkpoint, and it will lead to many hours of debugging \:
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def plot_examples(low_res_folder, gen):
+    files = os.listdir(low_res_folder)
+
+    gen.eval()
+    for file in files:
+        image = Image.open(os.path.join(low_res_folder, file))
+        with torch.no_grad():
+            upscaled_img = gen(
+                config.test_transform(image=np.asarray(image))["image"]
+                .unsqueeze(0)
+                .to(config.DEVICE)
+            )
+        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
+    gen.train()
diff --git a/ML/Pytorch/GANs/StyleGAN/config.py b/ML/Pytorch/GANs/StyleGAN/config.py
new file mode 100644
index 0000000..982ddac
--- /dev/null
+++ b/ML/Pytorch/GANs/StyleGAN/config.py
@@ -0,0 +1,25 
@@ +import albumentations as A +import cv2 +import torch +from math import log2 + +from albumentations.pytorch import ToTensorV2 +#from utils import seed_everything + +START_TRAIN_AT_IMG_SIZE = 32 +DATASET = 'FFHQ_32' +CHECKPOINT_GEN = "generator.pth" +CHECKPOINT_CRITIC = "critic.pth" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +LOAD_MODEL = False +SAVE_MODEL = True +LEARNING_RATE = 1e-3 +BATCH_SIZES = [32, 32, 32, 32, 32, 16, 8, 4, 2] +CHANNELS_IMG = 3 +Z_DIM = 512 +W_DIM = 512 +IN_CHANNELS = 512 +LAMBDA_GP = 10 +PROGRESSIVE_EPOCHS = [50] * 100 +FIXED_NOISE = torch.randn((8, Z_DIM)).to(DEVICE) +NUM_WORKERS = 6 \ No newline at end of file diff --git a/ML/Pytorch/GANs/StyleGAN/make_resized_data.py b/ML/Pytorch/GANs/StyleGAN/make_resized_data.py new file mode 100644 index 0000000..1291835 --- /dev/null +++ b/ML/Pytorch/GANs/StyleGAN/make_resized_data.py @@ -0,0 +1,9 @@ +import os +from PIL import Image +from tqdm import tqdm + +root_dir = "FFHQ/images1024x1024" + +for file in tqdm(os.listdir(root_dir)): + img = Image.open(root_dir+ "/"+file).resize((128, 128)) + img.save("FFHQ_resized/"+file) diff --git a/ML/Pytorch/GANs/StyleGAN/model.py b/ML/Pytorch/GANs/StyleGAN/model.py new file mode 100644 index 0000000..c19d412 --- /dev/null +++ b/ML/Pytorch/GANs/StyleGAN/model.py @@ -0,0 +1,293 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from math import log2 + +factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32] + +class PixelNorm(nn.Module): + def __init__(self): + super(PixelNorm, self).__init__() + self.epsilon = 1e-8 + + def forward(self, x): + return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon) + + +class MappingNetwork(nn.Module): + def __init__(self, z_dim, w_dim): + super().__init__() + self.mapping = nn.Sequential( + PixelNorm(), + WSLinear(z_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + nn.ReLU(), + WSLinear(w_dim, w_dim), + ) + + def forward(self, x): + return self.mapping(x) + + +class InjectNoise(nn.Module): + def __init__(self, channels): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1)) + + def forward(self, x): + noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device) + return x + self.weight * noise + +class AdaIN(nn.Module): + def __init__(self, channels, w_dim): + super().__init__() + self.instance_norm = nn.InstanceNorm2d(channels) + self.style_scale = WSLinear(w_dim, channels) + self.style_bias = WSLinear(w_dim, channels) + + def forward(self, x, w): + x = self.instance_norm(x) + style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3) + style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3) + return style_scale * x + style_bias + + +class WSConv2d(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2, + ): + super(WSConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5 + self.bias = self.conv.bias + self.conv.bias = None + + # initialize conv layer + nn.init.normal_(self.conv.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1) + + +class WSLinear(nn.Module): + def __init__( + self, 
in_features, out_features, gain=2, + ): + super(WSLinear, self).__init__() + self.linear = nn.Linear(in_features, out_features) + self.scale = (gain / in_features)**0.5 + self.bias = self.linear.bias + self.linear.bias = None + + # initialize linear layer + nn.init.normal_(self.linear.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + return self.linear(x * self.scale) + self.bias + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels, w_dim): + super(GenBlock, self).__init__() + self.conv1 = WSConv2d(in_channels, out_channels) + self.conv2 = WSConv2d(out_channels, out_channels) + self.leaky = nn.LeakyReLU(0.2, inplace=True) + self.inject_noise1 = InjectNoise(out_channels) + self.inject_noise2 = InjectNoise(out_channels) + self.adain1 = AdaIN(out_channels, w_dim) + self.adain2 = AdaIN(out_channels, w_dim) + + def forward(self, x, w): + x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w) + x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w) + return x + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + self.conv1 = WSConv2d(in_channels, out_channels) + self.conv2 = WSConv2d(out_channels, out_channels) + self.leaky = nn.LeakyReLU(0.2) + + def forward(self, x): + x = self.leaky(self.conv1(x)) + x = self.leaky(self.conv2(x)) + return x + + +class Generator(nn.Module): + def __init__(self, z_dim, w_dim, in_channels, img_channels=3): + super(Generator, self).__init__() + self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4))) + self.map = MappingNetwork(z_dim, w_dim) + self.initial_adain1 = AdaIN(in_channels, w_dim) + self.initial_adain2 = AdaIN(in_channels, w_dim) + self.initial_noise1 = InjectNoise(in_channels) + self.initial_noise2 = InjectNoise(in_channels) + self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.leaky = nn.LeakyReLU(0.2, inplace=True) + + self.initial_rgb = WSConv2d( + in_channels, img_channels, kernel_size=1, stride=1, padding=0 + ) + self.prog_blocks, self.rgb_layers = ( + nn.ModuleList([]), + nn.ModuleList([self.initial_rgb]), + ) + + for i in range(len(factors) - 1): # -1 to prevent index error because of factors[i+1] + conv_in_c = int(in_channels * factors[i]) + conv_out_c = int(in_channels * factors[i + 1]) + self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim)) + self.rgb_layers.append( + WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0) + ) + + def fade_in(self, alpha, upscaled, generated): + # alpha should be scalar within [0, 1], and upscale.shape == generated.shape + return torch.tanh(alpha * generated + (1 - alpha) * upscaled) + + def forward(self, noise, alpha, steps): + w = self.map(noise) + x = self.initial_adain1(self.initial_noise1(self.starting_constant), w) + x = self.initial_conv(x) + out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w) + + if steps == 0: + return self.initial_rgb(x) + + for step in range(steps): + upscaled = F.interpolate(out, scale_factor=2, mode="bilinear") + out = self.prog_blocks[step](upscaled, w) + + # The number of channels in upscale will stay the same, while + # out which has moved through prog_blocks might change. 
To ensure + # we can convert both to rgb we use different rgb_layers + # (steps-1) and steps for upscaled, out respectively + final_upscaled = self.rgb_layers[steps - 1](upscaled) + final_out = self.rgb_layers[steps](out) + return self.fade_in(alpha, final_upscaled, final_out) + + +class Discriminator(nn.Module): + def __init__(self, in_channels, img_channels=3): + super(Discriminator, self).__init__() + self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([]) + self.leaky = nn.LeakyReLU(0.2) + + # here we work back ways from factors because the discriminator + # should be mirrored from the generator. So the first prog_block and + # rgb layer we append will work for input size 1024x1024, then 512->256-> etc + for i in range(len(factors) - 1, 0, -1): + conv_in = int(in_channels * factors[i]) + conv_out = int(in_channels * factors[i - 1]) + self.prog_blocks.append(ConvBlock(conv_in, conv_out)) + self.rgb_layers.append( + WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0) + ) + + # perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size + # did this to "mirror" the generator initial_rgb + self.initial_rgb = WSConv2d( + img_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.rgb_layers.append(self.initial_rgb) + self.avg_pool = nn.AvgPool2d( + kernel_size=2, stride=2 + ) # down sampling using avg pool + + # this is the block for 4x4 input size + self.final_block = nn.Sequential( + # +1 to in_channels because we concatenate from MiniBatch std + WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1), + nn.LeakyReLU(0.2), + WSConv2d( + in_channels, 1, kernel_size=1, padding=0, stride=1 + ), # we use this instead of linear layer + ) + + def fade_in(self, alpha, downscaled, out): + """Used to fade in downscaled using avg pooling and output from CNN""" + # alpha should be scalar within [0, 1], and upscale.shape == generated.shape + return alpha * out + (1 - alpha) * downscaled + + def minibatch_std(self, x): + batch_statistics = ( + torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3]) + ) + # we take the std for each example (across all channels, and pixels) then we repeat it + # for a single channel and concatenate it with the image. In this way the discriminator + # will get information about the variation in the batch/image + return torch.cat([x, batch_statistics], dim=1) + + def forward(self, x, alpha, steps): + # where we should start in the list of prog_blocks, maybe a bit confusing but + # the last is for the 4x4. So example let's say steps=1, then we should start + # at the second to last because input_size will be 8x8. 
If steps==0 we just
+        # use the final block
+        cur_step = len(self.prog_blocks) - steps
+
+        # convert from rgb as the initial step; this depends on
+        # the image size (each resolution has its own rgb layer)
+        out = self.leaky(self.rgb_layers[cur_step](x))
+
+        if steps == 0:  # i.e., image is 4x4
+            out = self.minibatch_std(out)
+            return self.final_block(out).view(out.shape[0], -1)
+
+        # because prog_blocks might change the channels, the downscaled path uses the
+        # rgb_layer of the previous/smaller size, which corresponds to +1 in the indexing
+        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
+        out = self.avg_pool(self.prog_blocks[cur_step](out))
+
+        # the fade_in is done first, between the downscaled input and the block output;
+        # this is the opposite order from the generator
+        out = self.fade_in(alpha, downscaled, out)
+
+        for step in range(cur_step + 1, len(self.prog_blocks)):
+            out = self.prog_blocks[step](out)
+            out = self.avg_pool(out)
+
+        out = self.minibatch_std(out)
+        return self.final_block(out).view(out.shape[0], -1)
+
+
+if __name__ == "__main__":
+    Z_DIM = 512
+    W_DIM = 512
+    IN_CHANNELS = 512
+    gen = Generator(Z_DIM, W_DIM, IN_CHANNELS, img_channels=3).to("cuda")
+    disc = Discriminator(IN_CHANNELS, img_channels=3).to("cuda")
+
+    tot = 0
+    for param in gen.parameters():
+        tot += param.numel()
+
+    print(f"Generator parameters: {tot}")
+
+    # quick shape check at every resolution
+    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+        num_steps = int(log2(img_size / 4))
+        x = torch.randn((2, Z_DIM)).to("cuda")
+        z = gen(x, 0.5, steps=num_steps)
+        assert z.shape == (2, 3, img_size, img_size)
+        out = disc(z, alpha=0.5, steps=num_steps)
+        assert out.shape == (2, 1)
+        print(f"Success! At img size: {img_size}")
diff --git a/ML/Pytorch/GANs/StyleGAN/prepare_data.py b/ML/Pytorch/GANs/StyleGAN/prepare_data.py
new file mode 100644
index 0000000..b440123
--- /dev/null
+++ b/ML/Pytorch/GANs/StyleGAN/prepare_data.py
@@ -0,0 +1,22 @@
+from PIL import Image
+import os
+from multiprocessing import Pool
+
+root_dir = "FFHQ/images1024x1024/"
+files = os.listdir(root_dir)
+
+
+def resize(file, size, folder_to_save):
+    image = Image.open(root_dir + file).resize((size, size), Image.LANCZOS)
+    image.save(folder_to_save + file, quality=100)
+
+
+if __name__ == "__main__":
+    for img_size in [4, 8, 512, 1024]:
+        folder_name = "FFHQ_" + str(img_size) + "/images/"
+        if not os.path.isdir(folder_name):
+            os.makedirs(folder_name)
+
+        data = [(file, img_size, folder_name) for file in files]
+        with Pool() as pool:
+            pool.starmap(resize, data)
diff --git a/ML/Pytorch/GANs/StyleGAN/readme_important.txt b/ML/Pytorch/GANs/StyleGAN/readme_important.txt
new file mode 100644
index 0000000..eca4a49
--- /dev/null
+++ b/ML/Pytorch/GANs/StyleGAN/readme_important.txt
@@ -0,0 +1,2 @@
+This implementation doesn't work yet; I need to debug it and find where I've made a mistake. It seems to do something
+that makes sense, but it's nowhere near the level of performance reported in the original paper.
\ No newline at end of file
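For orientation before train.py: the fade-in factor `alpha` ramps linearly from ~0 to 1 over the first half of each resolution stage. A standalone sketch of that schedule, with placeholder dataset and batch sizes:

```python
# mirrors: alpha += cur_batch_size / ((PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset))
epochs_this_stage = 50   # config.PROGRESSIVE_EPOCHS[step]
dataset_len = 70_000     # placeholder for len(dataset)
batch_size = 32

alpha, batches = 1e-5, 0
while alpha < 1:
    alpha = min(alpha + batch_size / (0.5 * epochs_this_stage * dataset_len), 1)
    batches += 1

# alpha hits 1 after ~0.5 * epochs * (dataset_len / batch_size) batches,
# i.e. roughly 25 of the 50 epochs in this stage
print(batches)
```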
diff --git a/ML/Pytorch/GANs/StyleGAN/train.py b/ML/Pytorch/GANs/StyleGAN/train.py
new file mode 100644
index 0000000..212b587
--- /dev/null
+++ b/ML/Pytorch/GANs/StyleGAN/train.py
@@ -0,0 +1,201 @@
+""" Training of ProGAN using WGAN-GP loss"""
+
+import torch
+import torch.optim as optim
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils import (
+    gradient_penalty,
+    plot_to_tensorboard,
+    save_checkpoint,
+    load_checkpoint,
+    EMA,
+)
+from model import Discriminator, Generator
+from math import log2
+from tqdm import tqdm
+import config
+
+torch.backends.cudnn.benchmark = True
+
+
+def get_loader(image_size):
+    transform = transforms.Compose(
+        [
+            #transforms.Resize((image_size, image_size)),
+            transforms.ToTensor(),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.Normalize(
+                [0.5 for _ in range(config.CHANNELS_IMG)],
+                [0.5 for _ in range(config.CHANNELS_IMG)],
+            ),
+        ]
+    )
+    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
+    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
+    loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=config.NUM_WORKERS,
+        pin_memory=True,
+    )
+    return loader, dataset
+
+
+def train_fn(
+    critic,
+    gen,
+    loader,
+    dataset,
+    step,
+    alpha,
+    opt_critic,
+    opt_gen,
+    tensorboard_step,
+    writer,
+    scaler_gen,
+    scaler_critic,
+    ema,
+):
+    loop = tqdm(loader, leave=True)
+    # gen2 receives the EMA weights and is only used for visualization
+    gen2 = Generator(
+        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+    ).to(config.DEVICE)
+
+    for batch_idx, (real, _) in enumerate(loop):
+        real = real.to(config.DEVICE)
+        cur_batch_size = real.shape[0]
+
+        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
+        # which is equivalent to minimizing the negative of the expression
+        noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
+
+        with torch.cuda.amp.autocast():
+            fake = gen(noise, alpha, step)
+            critic_real = critic(real, alpha, step)
+            critic_fake = critic(fake.detach(), alpha, step)
+            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
+            loss_critic = (
+                -(torch.mean(critic_real) - torch.mean(critic_fake))
+                + config.LAMBDA_GP * gp
+                + (0.001 * torch.mean(critic_real ** 2))
+            )
+
+        opt_critic.zero_grad()
+        scaler_critic.scale(loss_critic).backward()
+        scaler_critic.step(opt_critic)
+        scaler_critic.update()
+
+        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
+        with torch.cuda.amp.autocast():
+            gen_fake = critic(fake, alpha, step)
+            loss_gen = -torch.mean(gen_fake)
+
+        opt_gen.zero_grad()
+        scaler_gen.scale(loss_gen).backward()
+        scaler_gen.step(opt_gen)
+        scaler_gen.update()
+
+        # Update alpha and ensure it stays below 1
+        alpha += cur_batch_size / (
+            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
+        )
+        alpha = min(alpha, 1)
+
+        if batch_idx % 100 == 0:
+            ema(gen)
+            with torch.no_grad():
+                ema.copy_weights_to(gen2)
+                fixed_fakes = gen2(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
+            plot_to_tensorboard(
+                writer,
+                loss_critic.item(),
+                loss_gen.item(),
+                real.detach(),
+                fixed_fakes.detach(),
+                tensorboard_step,
+            )
+            tensorboard_step += 1
+
+        loop.set_postfix(
+            gp=gp.item(),
+            loss_critic=loss_critic.item(),
+        )
+
+
+    return tensorboard_step, alpha
+
+
+def main():
+    # initialize gen and disc, note: discriminator should be called critic,
+    # 
according to WGAN paper (since it no longer outputs between [0, 1]) + # but really who cares.. + gen = Generator( + config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG + ).to(config.DEVICE) + critic = Discriminator( + config.IN_CHANNELS, img_channels=config.CHANNELS_IMG + ).to(config.DEVICE) + ema = EMA(gamma=0.999, save_frequency=2000) + # initialize optimizers and scalers for FP16 training + opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]}, + {"params": gen.map.parameters(), "lr": 1e-5}], lr=config.LEARNING_RATE, betas=(0.0, 0.99)) + opt_critic = optim.Adam( + critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99) + ) + scaler_critic = torch.cuda.amp.GradScaler() + scaler_gen = torch.cuda.amp.GradScaler() + + # for tensorboard plotting + writer = SummaryWriter(f"logs/gan") + + if config.LOAD_MODEL: + load_checkpoint( + config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE, + ) + load_checkpoint( + config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE, + ) + + gen.train() + critic.train() + + tensorboard_step = 0 + # start at step that corresponds to img size that we set in config + step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4)) + for num_epochs in config.PROGRESSIVE_EPOCHS[step:]: + alpha = 1e-5 # start with very low alpha + loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4 + print(f"Current image size: {4 * 2 ** step}") + + for epoch in range(num_epochs): + print(f"Epoch [{epoch+1}/{num_epochs}]") + tensorboard_step, alpha = train_fn( + critic, + gen, + loader, + dataset, + step, + alpha, + opt_critic, + opt_gen, + tensorboard_step, + writer, + scaler_gen, + scaler_critic, + ema, + ) + + if config.SAVE_MODEL: + save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN) + save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC) + + step += 1 # progress to the next img size + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ML/Pytorch/GANs/StyleGAN/utils.py b/ML/Pytorch/GANs/StyleGAN/utils.py new file mode 100644 index 0000000..e3d8303 --- /dev/null +++ b/ML/Pytorch/GANs/StyleGAN/utils.py @@ -0,0 +1,126 @@ +import torch +import random +import numpy as np +import os +import torchvision +import torch.nn as nn +import warnings + +# Print losses occasionally and print to tensorboard +def plot_to_tensorboard( + writer, loss_critic, loss_gen, real, fake, tensorboard_step +): + writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step) + + with torch.no_grad(): + # take out (up to) 32 examples + img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True) + img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True) + writer.add_image("Real", img_grid_real, global_step=tensorboard_step) + writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step) + + +def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"): + BATCH_SIZE, C, H, W = real.shape + beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device) + interpolated_images = real * beta + fake.detach() * (1 - beta) + interpolated_images.requires_grad_(True) + + # Calculate critic scores + mixed_scores = critic(interpolated_images, alpha, train_step) + + # Take the gradient of the scores with respect to the images + gradient = torch.autograd.grad( + inputs=interpolated_images, + outputs=mixed_scores, + grad_outputs=torch.ones_like(mixed_scores), + create_graph=True, + 
retain_graph=True,
+    )[0]
+    gradient = gradient.view(gradient.shape[0], -1)
+    gradient_norm = gradient.norm(2, dim=1)
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+    return gradient_penalty
+
+
+def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+    print("=> Saving checkpoint")
+    checkpoint = {
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    torch.save(checkpoint, filename)
+
+
+def load_checkpoint(checkpoint_file, model, optimizer, lr):
+    print("=> Loading checkpoint")
+    checkpoint = torch.load(checkpoint_file, map_location="cuda")
+    model.load_state_dict(checkpoint["state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    # If we don't do this, the optimizer keeps the learning rate of the old
+    # checkpoint, and it will lead to many hours of debugging \:
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def seed_everything(seed=42):
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+class EMA:
+    # Found this useful (thanks alexis-jacq):
+    # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3
+    def __init__(self, gamma=0.99, save=True, save_frequency=100, save_filename="ema_weights.pth"):
+        """
+        Initialize the decay factor for the exponential
+        moving average and the dictionary where we store
+        the model parameters
+        """
+        self.gamma = gamma
+        self.registered = {}
+        self.save_filename = save_filename
+        self.save_frequency = save_frequency
+        self.count = 0
+
+        if save_filename in os.listdir("."):
+            self.registered = torch.load(self.save_filename)
+
+        if not save:
+            warnings.warn("Note that the exponential moving average weights will not be saved to a .pth file!")
+
+    def register_weights(self, model):
+        """
+        Registers the weights of the model which will
+        later be used when we take the moving average
+        """
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                self.registered[name] = param.clone().detach()
+
+    def __call__(self, model):
+        self.count += 1
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                new_weight = param.clone().detach() if name not in self.registered else self.gamma * self.registered[name] + (1 - self.gamma) * param.detach()  # decay the average toward the current weights
+                self.registered[name] = new_weight
+
+        if self.count % self.save_frequency == 0:
+            self.save_ema_weights()
+
+    def copy_weights_to(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                param.data = self.registered[name]
+
+    def save_ema_weights(self):
+        torch.save(self.registered, self.save_filename)
+
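A minimal usage sketch for the `EMA` helper above (the model and step counts are placeholders): register the weights once, update the running average after each optimizer step, then copy the smoothed weights into a separate copy of the network for evaluation, as train.py does with `gen2`.

```python
import torch.nn as nn

model = nn.Linear(10, 10)    # stand-in for the Generator
shadow = nn.Linear(10, 10)   # same architecture, receives the EMA weights

ema = EMA(gamma=0.999, save_frequency=10_000)
ema.register_weights(model)

for _ in range(100):         # training loop: optimizer.step() would go here
    ema(model)               # fold the current weights into the running average

ema.copy_weights_to(shadow)  # evaluate with the smoothed weights
```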