checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,3 +1,12 @@
"""
Simple GAN using fully connected layers
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
import torch.optim as optim
@@ -48,7 +57,10 @@ disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
[
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
@@ -104,4 +116,4 @@ for epoch in range(num_epochs):
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
step += 1
step += 1

View File

@@ -1,5 +1,9 @@
"""
Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
# input: N x channels_img x 64 x 64
nn.Conv2d(
channels_img, features_d, kernel_size=4, stride=2, padding=1
),
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
# _block(in_channels, out_channels, kernel_size, stride, padding)
self._block(features_d, features_d * 2, 4, 2, 1),
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
padding,
bias=False,
),
#nn.BatchNorm2d(out_channels),
# nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2),
)
@@ -68,7 +70,7 @@ class Generator(nn.Module):
padding,
bias=False,
),
#nn.BatchNorm2d(out_channels),
# nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
@@ -82,6 +84,7 @@ def initialize_weights(model):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)
def test():
N, in_channels, H, W = 8, 3, 64, 64
noise_dim = 100
@@ -91,6 +94,8 @@ def test():
gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
# test()
if __name__ == "__main__":
test()

View File

@@ -1,6 +1,10 @@
"""
Training of DCGAN network on MNIST dataset with Discriminator
and Generator imported from models.py
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
)
# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
download=True)
dataset = datasets.MNIST(
root="dataset/", train=True, transform=transforms, download=True
)
# comment mnist above and uncomment below if train on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
@@ -92,14 +97,10 @@ for epoch in range(NUM_EPOCHS):
with torch.no_grad():
fake = gen(fixed_noise)
# take out (up to) 32 examples
img_grid_real = torchvision.utils.make_grid(
real[:32], normalize=True
)
img_grid_fake = torchvision.utils.make_grid(
fake[:32], normalize=True
)
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1
step += 1

View File

@@ -2,6 +2,10 @@
Discriminator and Generator implementation from DCGAN paper,
with removed Sigmoid() as output from Discriminator (and therefore
it should be called critic)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -93,6 +97,7 @@ def test():
gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
# test()
if __name__ == "__main__":
test()

View File

@@ -1,5 +1,9 @@
"""
Training of DCGAN network with WGAN loss
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -9,6 +13,7 @@ import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
@@ -61,7 +66,7 @@ critic.train()
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (data, _) in enumerate(loader):
for batch_idx, (data, _) in enumerate(tqdm(loader)):
data = data.to(device)
cur_batch_size = data.shape[0]
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
step += 1
gen.train()
critic.train()
critic.train()

View File

@@ -1,5 +1,9 @@
"""
Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -24,7 +28,12 @@ class Discriminator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
),
nn.InstanceNorm2d(out_channels, affine=True),
nn.LeakyReLU(0.2),
@@ -53,7 +62,12 @@ class Generator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),

View File

@@ -1,5 +1,9 @@
"""
Training of WGAN-GP
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -10,6 +14,7 @@ import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import gradient_penalty, save_checkpoint, load_checkpoint
from model import Discriminator, Generator, initialize_weights
@@ -31,13 +36,14 @@ transforms = transforms.Compose(
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
),
]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# comment mnist above and uncomment below for training on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
@@ -66,7 +72,7 @@ critic.train()
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (real, _) in enumerate(loader):
for batch_idx, (real, _) in enumerate(tqdm(loader)):
real = real.to(device)
cur_batch_size = real.shape[0]
@@ -108,4 +114,4 @@ for epoch in range(NUM_EPOCHS):
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1
step += 1

View File

@@ -11,7 +11,7 @@ LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 10
LOAD_MODEL = True
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
@@ -24,6 +24,6 @@ transforms = A.Compose(
A.HorizontalFlip(p=0.5),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(),
],
],
additional_targets={"image0": "image"},
)
)

View File

@@ -1,11 +1,28 @@
"""
Discriminator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
nn.Conv2d(
in_channels,
out_channels,
4,
stride,
1,
bias=True,
padding_mode="reflect",
),
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2, inplace=True),
)
@@ -32,15 +49,27 @@ class Discriminator(nn.Module):
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
layers.append(
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
)
in_channels = feature
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
layers.append(
nn.Conv2d(
in_channels,
1,
kernel_size=4,
stride=1,
padding=1,
padding_mode="reflect",
)
)
self.model = nn.Sequential(*layers)
def forward(self, x):
x = self.initial(x)
return torch.sigmoid(self.model(x))
def test():
x = torch.randn((5, 3, 256, 256))
model = Discriminator(in_channels=3)
@@ -50,4 +79,3 @@ def test():
if __name__ == "__main__":
test()

View File

@@ -1,6 +1,15 @@
"""
Generator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
super().__init__()
@@ -9,12 +18,13 @@ class ConvBlock(nn.Module):
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels),
nn.ReLU(inplace=True) if use_act else nn.Identity()
nn.ReLU(inplace=True) if use_act else nn.Identity(),
)
def forward(self, x):
return self.conv(x)
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
@@ -26,31 +36,70 @@ class ResidualBlock(nn.Module):
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
def __init__(self, img_channels, num_features = 64, num_residuals=9):
def __init__(self, img_channels, num_features=64, num_residuals=9):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
nn.Conv2d(
img_channels,
num_features,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
),
nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True),
)
self.down_blocks = nn.ModuleList(
[
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
ConvBlock(
num_features, num_features * 2, kernel_size=3, stride=2, padding=1
),
ConvBlock(
num_features * 2,
num_features * 4,
kernel_size=3,
stride=2,
padding=1,
),
]
)
self.res_blocks = nn.Sequential(
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
*[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
)
self.up_blocks = nn.ModuleList(
[
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(
num_features * 4,
num_features * 2,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
ConvBlock(
num_features * 2,
num_features * 1,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
]
)
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
self.last = nn.Conv2d(
num_features * 1,
img_channels,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
)
def forward(self, x):
x = self.initial(x)
@@ -61,6 +110,7 @@ class Generator(nn.Module):
x = layer(x)
return torch.tanh(self.last(x))
def test():
img_channels = 3
img_size = 256
@@ -68,5 +118,6 @@ def test():
gen = Generator(img_channels, 9)
print(gen(x).shape)
if __name__ == "__main__":
test()

View File

@@ -1,3 +1,11 @@
"""
Training for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
from dataset import HorseZebraDataset
import sys
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
def train_fn(
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
H_reals = 0
H_fakes = 0
loop = tqdm(loader, leave=True)
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
# put it together
D_loss = (D_H_loss + D_Z_loss)/2
D_loss = (D_H_loss + D_Z_loss) / 2
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
g_scaler.update()
if idx % 200 == 0:
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
def main():
@@ -111,23 +121,39 @@ def main():
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_H,
gen_H,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_Z,
gen_Z,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_H,
disc_H,
opt_disc,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_Z,
disc_Z,
opt_disc,
config.LEARNING_RATE,
)
dataset = HorseZebraDataset(
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
root_horse=config.TRAIN_DIR + "/horses",
root_zebra=config.TRAIN_DIR + "/zebras",
transform=config.transforms,
)
val_dataset = HorseZebraDataset(
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
root_horse="cyclegan_test/horse1",
root_zebra="cyclegan_test/zebra1",
transform=config.transforms,
)
val_loader = DataLoader(
val_dataset,
@@ -140,13 +166,25 @@ def main():
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True
pin_memory=True,
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
train_fn(
disc_H,
disc_Z,
gen_Z,
gen_H,
loader,
opt_disc,
opt_gen,
L1,
mse,
d_scaler,
g_scaler,
)
if config.SAVE_MODEL:
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
@@ -154,5 +192,6 @@ def main():
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
if __name__ == "__main__":
main()
main()