checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,3 +1,12 @@
"""
Simple GAN using fully connected layers
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
@@ -48,7 +57,10 @@ disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device) gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device) fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),] [
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]
) )
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True) dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
@@ -104,4 +116,4 @@ for epoch in range(num_epochs):
writer_real.add_image( writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step "Mnist Real Images", img_grid_real, global_step=step
) )
step += 1 step += 1

View File

@@ -1,5 +1,9 @@
""" """
Discriminator and Generator implementation from DCGAN paper Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
self.disc = nn.Sequential( self.disc = nn.Sequential(
# input: N x channels_img x 64 x 64 # input: N x channels_img x 64 x 64
nn.Conv2d( nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
channels_img, features_d, kernel_size=4, stride=2, padding=1
),
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
# _block(in_channels, out_channels, kernel_size, stride, padding) # _block(in_channels, out_channels, kernel_size, stride, padding)
self._block(features_d, features_d * 2, 4, 2, 1), self._block(features_d, features_d * 2, 4, 2, 1),
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
padding, padding,
bias=False, bias=False,
), ),
#nn.BatchNorm2d(out_channels), # nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
) )
@@ -68,7 +70,7 @@ class Generator(nn.Module):
padding, padding,
bias=False, bias=False,
), ),
#nn.BatchNorm2d(out_channels), # nn.BatchNorm2d(out_channels),
nn.ReLU(), nn.ReLU(),
) )
@@ -82,6 +84,7 @@ def initialize_weights(model):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02) nn.init.normal_(m.weight.data, 0.0, 0.02)
def test(): def test():
N, in_channels, H, W = 8, 3, 64, 64 N, in_channels, H, W = 8, 3, 64, 64
noise_dim = 100 noise_dim = 100
@@ -91,6 +94,8 @@ def test():
gen = Generator(noise_dim, in_channels, 8) gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1)) z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed" assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
# test() if __name__ == "__main__":
test()

View File

@@ -1,6 +1,10 @@
""" """
Training of DCGAN network on MNIST dataset with Discriminator Training of DCGAN network on MNIST dataset with Discriminator
and Generator imported from models.py and Generator imported from models.py
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
) )
# If you train on MNIST, remember to set channels_img to 1 # If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, dataset = datasets.MNIST(
download=True) root="dataset/", train=True, transform=transforms, download=True
)
# comment mnist above and uncomment below if train on CelebA # comment mnist above and uncomment below if train on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms) # dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device) gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device) disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
@@ -92,14 +97,10 @@ for epoch in range(NUM_EPOCHS):
with torch.no_grad(): with torch.no_grad():
fake = gen(fixed_noise) fake = gen(fixed_noise)
# take out (up to) 32 examples # take out (up to) 32 examples
img_grid_real = torchvision.utils.make_grid( img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
real[:32], normalize=True img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
)
img_grid_fake = torchvision.utils.make_grid(
fake[:32], normalize=True
)
writer_real.add_image("Real", img_grid_real, global_step=step) writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step) writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1 step += 1

View File

@@ -2,6 +2,10 @@
Discriminator and Generator implementation from DCGAN paper, Discriminator and Generator implementation from DCGAN paper,
with removed Sigmoid() as output from Discriminator (and therefor with removed Sigmoid() as output from Discriminator (and therefor
it should be called critic) it should be called critic)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -93,6 +97,7 @@ def test():
gen = Generator(noise_dim, in_channels, 8) gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1)) z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed" assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
if __name__ == "__main__":
# test() test()

View File

@@ -1,5 +1,9 @@
""" """
Training of DCGAN network with WGAN loss Training of DCGAN network with WGAN loss
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -9,6 +13,7 @@ import torchvision
import torchvision.datasets as datasets import torchvision.datasets as datasets
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights from model import Discriminator, Generator, initialize_weights
@@ -61,7 +66,7 @@ critic.train()
for epoch in range(NUM_EPOCHS): for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised # Target labels not needed! <3 unsupervised
for batch_idx, (data, _) in enumerate(loader): for batch_idx, (data, _) in enumerate(tqdm(loader)):
data = data.to(device) data = data.to(device)
cur_batch_size = data.shape[0] cur_batch_size = data.shape[0]
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
step += 1 step += 1
gen.train() gen.train()
critic.train() critic.train()

View File

@@ -1,5 +1,9 @@
""" """
Discriminator and Generator implementation from DCGAN paper Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -24,7 +28,12 @@ class Discriminator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding): def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential( return nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False, in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
), ),
nn.InstanceNorm2d(out_channels, affine=True), nn.InstanceNorm2d(out_channels, affine=True),
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
@@ -53,7 +62,12 @@ class Generator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding): def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential( return nn.Sequential(
nn.ConvTranspose2d( nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False, in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
), ),
nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels),
nn.ReLU(), nn.ReLU(),

View File

@@ -1,5 +1,9 @@
""" """
Training of WGAN-GP Training of WGAN-GP
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
""" """
import torch import torch
@@ -10,6 +14,7 @@ import torchvision.datasets as datasets
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import gradient_penalty, save_checkpoint, load_checkpoint from utils import gradient_penalty, save_checkpoint, load_checkpoint
from model import Discriminator, Generator, initialize_weights from model import Discriminator, Generator, initialize_weights
@@ -31,13 +36,14 @@ transforms = transforms.Compose(
transforms.Resize(IMAGE_SIZE), transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]), [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
),
] ]
) )
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True) dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# comment mnist above and uncomment below for training on CelebA # comment mnist above and uncomment below for training on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms) # dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
@@ -66,7 +72,7 @@ critic.train()
for epoch in range(NUM_EPOCHS): for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised # Target labels not needed! <3 unsupervised
for batch_idx, (real, _) in enumerate(loader): for batch_idx, (real, _) in enumerate(tqdm(loader)):
real = real.to(device) real = real.to(device)
cur_batch_size = real.shape[0] cur_batch_size = real.shape[0]
@@ -108,4 +114,4 @@ for epoch in range(NUM_EPOCHS):
writer_real.add_image("Real", img_grid_real, global_step=step) writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step) writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1 step += 1

View File

@@ -11,7 +11,7 @@ LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10 LAMBDA_CYCLE = 10
NUM_WORKERS = 4 NUM_WORKERS = 4
NUM_EPOCHS = 10 NUM_EPOCHS = 10
LOAD_MODEL = True LOAD_MODEL = False
SAVE_MODEL = True SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar" CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar" CHECKPOINT_GEN_Z = "genz.pth.tar"
@@ -24,6 +24,6 @@ transforms = A.Compose(
A.HorizontalFlip(p=0.5), A.HorizontalFlip(p=0.5),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255), A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(), ToTensorV2(),
], ],
additional_targets={"image0": "image"}, additional_targets={"image0": "image"},
) )

View File

@@ -1,11 +1,28 @@
"""
Discriminator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
class Block(nn.Module): class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride): def __init__(self, in_channels, out_channels, stride):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"), nn.Conv2d(
in_channels,
out_channels,
4,
stride,
1,
bias=True,
padding_mode="reflect",
),
nn.InstanceNorm2d(out_channels), nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2, inplace=True), nn.LeakyReLU(0.2, inplace=True),
) )
@@ -32,15 +49,27 @@ class Discriminator(nn.Module):
layers = [] layers = []
in_channels = features[0] in_channels = features[0]
for feature in features[1:]: for feature in features[1:]:
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2)) layers.append(
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
)
in_channels = feature in_channels = feature
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")) layers.append(
nn.Conv2d(
in_channels,
1,
kernel_size=4,
stride=1,
padding=1,
padding_mode="reflect",
)
)
self.model = nn.Sequential(*layers) self.model = nn.Sequential(*layers)
def forward(self, x): def forward(self, x):
x = self.initial(x) x = self.initial(x)
return torch.sigmoid(self.model(x)) return torch.sigmoid(self.model(x))
def test(): def test():
x = torch.randn((5, 3, 256, 256)) x = torch.randn((5, 3, 256, 256))
model = Discriminator(in_channels=3) model = Discriminator(in_channels=3)
@@ -50,4 +79,3 @@ def test():
if __name__ == "__main__": if __name__ == "__main__":
test() test()

View File

@@ -1,6 +1,15 @@
"""
Generator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
class ConvBlock(nn.Module): class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs): def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
super().__init__() super().__init__()
@@ -9,12 +18,13 @@ class ConvBlock(nn.Module):
if down if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs), else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels), nn.InstanceNorm2d(out_channels),
nn.ReLU(inplace=True) if use_act else nn.Identity() nn.ReLU(inplace=True) if use_act else nn.Identity(),
) )
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, channels): def __init__(self, channels):
super().__init__() super().__init__()
@@ -26,31 +36,70 @@ class ResidualBlock(nn.Module):
def forward(self, x): def forward(self, x):
return x + self.block(x) return x + self.block(x)
class Generator(nn.Module): class Generator(nn.Module):
def __init__(self, img_channels, num_features = 64, num_residuals=9): def __init__(self, img_channels, num_features=64, num_residuals=9):
super().__init__() super().__init__()
self.initial = nn.Sequential( self.initial = nn.Sequential(
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"), nn.Conv2d(
img_channels,
num_features,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
),
nn.InstanceNorm2d(num_features), nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
self.down_blocks = nn.ModuleList( self.down_blocks = nn.ModuleList(
[ [
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1), ConvBlock(
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1), num_features, num_features * 2, kernel_size=3, stride=2, padding=1
),
ConvBlock(
num_features * 2,
num_features * 4,
kernel_size=3,
stride=2,
padding=1,
),
] ]
) )
self.res_blocks = nn.Sequential( self.res_blocks = nn.Sequential(
*[ResidualBlock(num_features*4) for _ in range(num_residuals)] *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
) )
self.up_blocks = nn.ModuleList( self.up_blocks = nn.ModuleList(
[ [
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1), ConvBlock(
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1), num_features * 4,
num_features * 2,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
ConvBlock(
num_features * 2,
num_features * 1,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
] ]
) )
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect") self.last = nn.Conv2d(
num_features * 1,
img_channels,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
)
def forward(self, x): def forward(self, x):
x = self.initial(x) x = self.initial(x)
@@ -61,6 +110,7 @@ class Generator(nn.Module):
x = layer(x) x = layer(x)
return torch.tanh(self.last(x)) return torch.tanh(self.last(x))
def test(): def test():
img_channels = 3 img_channels = 3
img_size = 256 img_size = 256
@@ -68,5 +118,6 @@ def test():
gen = Generator(img_channels, 9) gen = Generator(img_channels, 9)
print(gen(x).shape) print(gen(x).shape)
if __name__ == "__main__": if __name__ == "__main__":
test() test()

View File

@@ -1,3 +1,11 @@
"""
Training for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch import torch
from dataset import HorseZebraDataset from dataset import HorseZebraDataset
import sys import sys
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
from discriminator_model import Discriminator from discriminator_model import Discriminator
from generator_model import Generator from generator_model import Generator
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
def train_fn(
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
H_reals = 0 H_reals = 0
H_fakes = 0 H_fakes = 0
loop = tqdm(loader, leave=True) loop = tqdm(loader, leave=True)
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
D_Z_loss = D_Z_real_loss + D_Z_fake_loss D_Z_loss = D_Z_real_loss + D_Z_fake_loss
# put it togethor # put it togethor
D_loss = (D_H_loss + D_Z_loss)/2 D_loss = (D_H_loss + D_Z_loss) / 2
opt_disc.zero_grad() opt_disc.zero_grad()
d_scaler.scale(D_loss).backward() d_scaler.scale(D_loss).backward()
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
g_scaler.update() g_scaler.update()
if idx % 200 == 0: if idx % 200 == 0:
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png") save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png") save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
def main(): def main():
@@ -111,23 +121,39 @@ def main():
if config.LOAD_MODEL: if config.LOAD_MODEL:
load_checkpoint( load_checkpoint(
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE, config.CHECKPOINT_GEN_H,
gen_H,
opt_gen,
config.LEARNING_RATE,
) )
load_checkpoint( load_checkpoint(
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE, config.CHECKPOINT_GEN_Z,
gen_Z,
opt_gen,
config.LEARNING_RATE,
) )
load_checkpoint( load_checkpoint(
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE, config.CHECKPOINT_CRITIC_H,
disc_H,
opt_disc,
config.LEARNING_RATE,
) )
load_checkpoint( load_checkpoint(
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE, config.CHECKPOINT_CRITIC_Z,
disc_Z,
opt_disc,
config.LEARNING_RATE,
) )
dataset = HorseZebraDataset( dataset = HorseZebraDataset(
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms root_horse=config.TRAIN_DIR + "/horses",
root_zebra=config.TRAIN_DIR + "/zebras",
transform=config.transforms,
) )
val_dataset = HorseZebraDataset( val_dataset = HorseZebraDataset(
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms root_horse="cyclegan_test/horse1",
root_zebra="cyclegan_test/zebra1",
transform=config.transforms,
) )
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
@@ -140,13 +166,25 @@ def main():
batch_size=config.BATCH_SIZE, batch_size=config.BATCH_SIZE,
shuffle=True, shuffle=True,
num_workers=config.NUM_WORKERS, num_workers=config.NUM_WORKERS,
pin_memory=True pin_memory=True,
) )
g_scaler = torch.cuda.amp.GradScaler() g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler() d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS): for epoch in range(config.NUM_EPOCHS):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler) train_fn(
disc_H,
disc_Z,
gen_Z,
gen_H,
loader,
opt_disc,
opt_gen,
L1,
mse,
d_scaler,
g_scaler,
)
if config.SAVE_MODEL: if config.SAVE_MODEL:
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H) save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
@@ -154,5 +192,6 @@ def main():
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H) save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z) save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,131 +0,0 @@
"""
Example code of how to code GANs and more specifically DCGAN,
for more information about DCGANs read: https://arxiv.org/abs/1511.06434
We then train the DCGAN on the MNIST dataset (toy dataset of handwritten digits)
and then generate our own. You can apply this more generally on really any dataset
but MNIST is simple enough to get the overall idea.
Video explanation: https://youtu.be/5RYETbFFQ7s
Got any questions leave a comment on youtube :)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-20 Initial coding
"""
# Imports
import torch
import torchvision
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
from torch.utils.data import (
DataLoader,
) # Gives easier dataset management and creates mini batches
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
from model_utils import (
Discriminator,
Generator,
) # Import our models we've defined (from DCGAN paper)
# Hyperparameters
lr = 0.0005
batch_size = 64
image_size = 64  # MNIST images are resized to 64x64 to match the DCGAN architecture
channels_img = 1  # MNIST is grayscale
channels_noise = 256
num_epochs = 10

# For how many channels Generator and Discriminator should use
features_d = 16
features_g = 16

my_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Normalize the single grayscale channel to roughly [-1, 1]
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(
    root="dataset/", train=True, transform=my_transforms, download=True
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create discriminator and generator
netD = Discriminator(channels_img, features_d).to(device)
netG = Generator(channels_noise, channels_img, features_g).to(device)

# Setup Optimizer for G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
netG.train()
netD.train()

criterion = nn.BCELoss()

# NOTE(review): real_label / fake_label appear unused — the training loop below
# hard-codes smoothed targets (0.9 / 0.1) instead; confirm before removing.
real_label = 1
fake_label = 0

# Fixed noise so generated samples are comparable across epochs in TensorBoard
fixed_noise = torch.randn(64, channels_noise, 1, 1).to(device)
writer_real = SummaryWriter(f"runs/GAN_MNIST/test_real")
writer_fake = SummaryWriter(f"runs/GAN_MNIST/test_fake")
step = 0
print("Starting Training...")

# Standard (DC)GAN training: alternate one discriminator step and one
# generator step per mini-batch, logging losses and sample grids to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(dataloader):
        data = data.to(device)
        # Use the actual batch size — the last batch may be smaller than
        # `batch_size` (renamed to avoid shadowing the hyperparameter above).
        cur_batch_size = data.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()
        # One-sided label smoothing: real targets are 0.9 instead of 1.0
        label = (torch.ones(cur_batch_size) * 0.9).to(device)
        output = netD(data).reshape(-1)
        lossD_real = criterion(output, label)
        D_x = output.mean().item()

        noise = torch.randn(cur_batch_size, channels_noise, 1, 1).to(device)
        fake = netG(noise)
        # Smoothed fake targets (0.1 instead of 0.0)
        label = (torch.ones(cur_batch_size) * 0.1).to(device)
        # detach() so this D update does not backprop into the generator
        output = netD(fake.detach()).reshape(-1)
        lossD_fake = criterion(output, label)

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        ### Train Generator: max log(D(G(z)))
        netG.zero_grad()
        label = torch.ones(cur_batch_size).to(device)
        output = netD(fake).reshape(-1)
        lossG = criterion(output, label)
        lossG.backward()
        optimizerG.step()

        # Print losses occasionally and log image grids to TensorBoard
        if batch_idx % 100 == 0:
            step += 1
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f} D(x): {D_x:.4f}"
            )

            with torch.no_grad():
                # Generate from the fixed noise so grids are comparable over time
                fake = netG(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

View File

@@ -1,4 +0,0 @@
### Generative Adversarial Network
DCGAN_mnist.py: main file and training network
model_utils.py: Generator and discriminator implementation

View File

@@ -1,76 +0,0 @@
"""
Discriminator and Generator implementation from DCGAN paper
that we import in the main (DCGAN_mnist.py) file.
"""
import torch
import torch.nn as nn
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps an N x channels_img x 64 x 64 image batch to an N x 1 x 1 x 1
    tensor of probabilities (via a final Sigmoid) that each image is real.

    Args:
        channels_img: number of channels in the input images (1 for MNIST).
        features_d: base channel multiplier for the conv stack.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # N x features_d x 32 x 32
            nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(
                features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(
                features_d * 4, features_d * 8, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2),
            # N x features_d*8 x 4 x 4
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            # N x 1 x 1 x 1
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-image realness probabilities, shape N x 1 x 1 x 1."""
        return self.net(x)
class Generator(nn.Module):
    """DCGAN generator.

    Maps an N x channels_noise x 1 x 1 latent batch to an
    N x channels_img x 64 x 64 image batch in [-1, 1] (via a final Tanh),
    matching the Normalize((0.5,), (0.5,)) preprocessing of the real data.

    Args:
        channels_noise: dimensionality of the latent vector.
        channels_img: number of channels in the generated images.
        features_g: base channel multiplier for the transposed-conv stack.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # N x channels_noise x 1 x 1
            nn.ConvTranspose2d(
                channels_noise, features_g * 16, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),
            # N x features_g*16 x 4 x 4
            nn.ConvTranspose2d(
                features_g * 16, features_g * 8, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(
                features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(
                features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated images, shape N x channels_img x 64 x 64."""
        return self.net(x)