mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
checked GAN code
This commit is contained in:
@@ -1,3 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Simple GAN using fully connected layers
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
@@ -48,7 +57,10 @@ disc = Discriminator(image_dim).to(device)
|
|||||||
gen = Generator(z_dim, image_dim).to(device)
|
gen = Generator(z_dim, image_dim).to(device)
|
||||||
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
|
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
|
||||||
transforms = transforms.Compose(
|
transforms = transforms.Compose(
|
||||||
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Discriminator and Generator implementation from DCGAN paper
|
Discriminator and Generator implementation from DCGAN paper
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
|
|||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.disc = nn.Sequential(
|
self.disc = nn.Sequential(
|
||||||
# input: N x channels_img x 64 x 64
|
# input: N x channels_img x 64 x 64
|
||||||
nn.Conv2d(
|
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
|
||||||
channels_img, features_d, kernel_size=4, stride=2, padding=1
|
|
||||||
),
|
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
# _block(in_channels, out_channels, kernel_size, stride, padding)
|
# _block(in_channels, out_channels, kernel_size, stride, padding)
|
||||||
self._block(features_d, features_d * 2, 4, 2, 1),
|
self._block(features_d, features_d * 2, 4, 2, 1),
|
||||||
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
|
|||||||
padding,
|
padding,
|
||||||
bias=False,
|
bias=False,
|
||||||
),
|
),
|
||||||
#nn.BatchNorm2d(out_channels),
|
# nn.BatchNorm2d(out_channels),
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -68,7 +70,7 @@ class Generator(nn.Module):
|
|||||||
padding,
|
padding,
|
||||||
bias=False,
|
bias=False,
|
||||||
),
|
),
|
||||||
#nn.BatchNorm2d(out_channels),
|
# nn.BatchNorm2d(out_channels),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,6 +84,7 @@ def initialize_weights(model):
|
|||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
|
||||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
N, in_channels, H, W = 8, 3, 64, 64
|
N, in_channels, H, W = 8, 3, 64, 64
|
||||||
noise_dim = 100
|
noise_dim = 100
|
||||||
@@ -91,6 +94,8 @@ def test():
|
|||||||
gen = Generator(noise_dim, in_channels, 8)
|
gen = Generator(noise_dim, in_channels, 8)
|
||||||
z = torch.randn((N, noise_dim, 1, 1))
|
z = torch.randn((N, noise_dim, 1, 1))
|
||||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||||
|
print("Success, tests passed!")
|
||||||
|
|
||||||
|
|
||||||
# test()
|
if __name__ == "__main__":
|
||||||
|
test()
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Training of DCGAN network on MNIST dataset with Discriminator
|
Training of DCGAN network on MNIST dataset with Discriminator
|
||||||
and Generator imported from models.py
|
and Generator imported from models.py
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If you train on MNIST, remember to set channels_img to 1
|
# If you train on MNIST, remember to set channels_img to 1
|
||||||
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
|
dataset = datasets.MNIST(
|
||||||
download=True)
|
root="dataset/", train=True, transform=transforms, download=True
|
||||||
|
)
|
||||||
|
|
||||||
# comment mnist above and uncomment below if train on CelebA
|
# comment mnist above and uncomment below if train on CelebA
|
||||||
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
|
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
|
||||||
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
|
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
|
||||||
@@ -92,12 +97,8 @@ for epoch in range(NUM_EPOCHS):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
fake = gen(fixed_noise)
|
fake = gen(fixed_noise)
|
||||||
# take out (up to) 32 examples
|
# take out (up to) 32 examples
|
||||||
img_grid_real = torchvision.utils.make_grid(
|
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
|
||||||
real[:32], normalize=True
|
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
|
||||||
)
|
|
||||||
img_grid_fake = torchvision.utils.make_grid(
|
|
||||||
fake[:32], normalize=True
|
|
||||||
)
|
|
||||||
|
|
||||||
writer_real.add_image("Real", img_grid_real, global_step=step)
|
writer_real.add_image("Real", img_grid_real, global_step=step)
|
||||||
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
|
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@
|
|||||||
Discriminator and Generator implementation from DCGAN paper,
|
Discriminator and Generator implementation from DCGAN paper,
|
||||||
with removed Sigmoid() as output from Discriminator (and therefor
|
with removed Sigmoid() as output from Discriminator (and therefor
|
||||||
it should be called critic)
|
it should be called critic)
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -93,6 +97,7 @@ def test():
|
|||||||
gen = Generator(noise_dim, in_channels, 8)
|
gen = Generator(noise_dim, in_channels, 8)
|
||||||
z = torch.randn((N, noise_dim, 1, 1))
|
z = torch.randn((N, noise_dim, 1, 1))
|
||||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||||
|
print("Success, tests passed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
# test()
|
test()
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Training of DCGAN network with WGAN loss
|
Training of DCGAN network with WGAN loss
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -9,6 +13,7 @@ import torchvision
|
|||||||
import torchvision.datasets as datasets
|
import torchvision.datasets as datasets
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from model import Discriminator, Generator, initialize_weights
|
from model import Discriminator, Generator, initialize_weights
|
||||||
|
|
||||||
@@ -61,7 +66,7 @@ critic.train()
|
|||||||
|
|
||||||
for epoch in range(NUM_EPOCHS):
|
for epoch in range(NUM_EPOCHS):
|
||||||
# Target labels not needed! <3 unsupervised
|
# Target labels not needed! <3 unsupervised
|
||||||
for batch_idx, (data, _) in enumerate(loader):
|
for batch_idx, (data, _) in enumerate(tqdm(loader)):
|
||||||
data = data.to(device)
|
data = data.to(device)
|
||||||
cur_batch_size = data.shape[0]
|
cur_batch_size = data.shape[0]
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Discriminator and Generator implementation from DCGAN paper
|
Discriminator and Generator implementation from DCGAN paper
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -24,7 +28,12 @@ class Discriminator(nn.Module):
|
|||||||
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, bias=False,
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
bias=False,
|
||||||
),
|
),
|
||||||
nn.InstanceNorm2d(out_channels, affine=True),
|
nn.InstanceNorm2d(out_channels, affine=True),
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
@@ -53,7 +62,12 @@ class Generator(nn.Module):
|
|||||||
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.ConvTranspose2d(
|
nn.ConvTranspose2d(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, bias=False,
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
bias=False,
|
||||||
),
|
),
|
||||||
nn.BatchNorm2d(out_channels),
|
nn.BatchNorm2d(out_channels),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Training of WGAN-GP
|
Training of WGAN-GP
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-01: Initial coding
|
||||||
|
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -10,6 +14,7 @@ import torchvision.datasets as datasets
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
from utils import gradient_penalty, save_checkpoint, load_checkpoint
|
from utils import gradient_penalty, save_checkpoint, load_checkpoint
|
||||||
from model import Discriminator, Generator, initialize_weights
|
from model import Discriminator, Generator, initialize_weights
|
||||||
|
|
||||||
@@ -31,13 +36,14 @@ transforms = transforms.Compose(
|
|||||||
transforms.Resize(IMAGE_SIZE),
|
transforms.Resize(IMAGE_SIZE),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(
|
transforms.Normalize(
|
||||||
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
|
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
||||||
# comment mnist above and uncomment below for training on CelebA
|
# comment mnist above and uncomment below for training on CelebA
|
||||||
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
@@ -66,7 +72,7 @@ critic.train()
|
|||||||
|
|
||||||
for epoch in range(NUM_EPOCHS):
|
for epoch in range(NUM_EPOCHS):
|
||||||
# Target labels not needed! <3 unsupervised
|
# Target labels not needed! <3 unsupervised
|
||||||
for batch_idx, (real, _) in enumerate(loader):
|
for batch_idx, (real, _) in enumerate(tqdm(loader)):
|
||||||
real = real.to(device)
|
real = real.to(device)
|
||||||
cur_batch_size = real.shape[0]
|
cur_batch_size = real.shape[0]
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ LAMBDA_IDENTITY = 0.0
|
|||||||
LAMBDA_CYCLE = 10
|
LAMBDA_CYCLE = 10
|
||||||
NUM_WORKERS = 4
|
NUM_WORKERS = 4
|
||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 10
|
||||||
LOAD_MODEL = True
|
LOAD_MODEL = False
|
||||||
SAVE_MODEL = True
|
SAVE_MODEL = True
|
||||||
CHECKPOINT_GEN_H = "genh.pth.tar"
|
CHECKPOINT_GEN_H = "genh.pth.tar"
|
||||||
CHECKPOINT_GEN_Z = "genz.pth.tar"
|
CHECKPOINT_GEN_Z = "genz.pth.tar"
|
||||||
|
|||||||
@@ -1,11 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Discriminator model for CycleGAN
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-05: Initial coding
|
||||||
|
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, stride):
|
def __init__(self, in_channels, out_channels, stride):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
|
nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
4,
|
||||||
|
stride,
|
||||||
|
1,
|
||||||
|
bias=True,
|
||||||
|
padding_mode="reflect",
|
||||||
|
),
|
||||||
nn.InstanceNorm2d(out_channels),
|
nn.InstanceNorm2d(out_channels),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
)
|
)
|
||||||
@@ -32,15 +49,27 @@ class Discriminator(nn.Module):
|
|||||||
layers = []
|
layers = []
|
||||||
in_channels = features[0]
|
in_channels = features[0]
|
||||||
for feature in features[1:]:
|
for feature in features[1:]:
|
||||||
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
|
layers.append(
|
||||||
|
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
|
||||||
|
)
|
||||||
in_channels = feature
|
in_channels = feature
|
||||||
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
|
layers.append(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
1,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
padding_mode="reflect",
|
||||||
|
)
|
||||||
|
)
|
||||||
self.model = nn.Sequential(*layers)
|
self.model = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.initial(x)
|
x = self.initial(x)
|
||||||
return torch.sigmoid(self.model(x))
|
return torch.sigmoid(self.model(x))
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
x = torch.randn((5, 3, 256, 256))
|
x = torch.randn((5, 3, 256, 256))
|
||||||
model = Discriminator(in_channels=3)
|
model = Discriminator(in_channels=3)
|
||||||
@@ -50,4 +79,3 @@ def test():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test()
|
test()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Generator model for CycleGAN
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-05: Initial coding
|
||||||
|
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
|
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -9,12 +18,13 @@ class ConvBlock(nn.Module):
|
|||||||
if down
|
if down
|
||||||
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
|
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
|
||||||
nn.InstanceNorm2d(out_channels),
|
nn.InstanceNorm2d(out_channels),
|
||||||
nn.ReLU(inplace=True) if use_act else nn.Identity()
|
nn.ReLU(inplace=True) if use_act else nn.Identity(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, channels):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -26,31 +36,70 @@ class ResidualBlock(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.block(x)
|
return x + self.block(x)
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, img_channels, num_features = 64, num_residuals=9):
|
def __init__(self, img_channels, num_features=64, num_residuals=9):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.initial = nn.Sequential(
|
self.initial = nn.Sequential(
|
||||||
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
|
nn.Conv2d(
|
||||||
|
img_channels,
|
||||||
|
num_features,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=1,
|
||||||
|
padding=3,
|
||||||
|
padding_mode="reflect",
|
||||||
|
),
|
||||||
nn.InstanceNorm2d(num_features),
|
nn.InstanceNorm2d(num_features),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
self.down_blocks = nn.ModuleList(
|
self.down_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
|
ConvBlock(
|
||||||
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
|
num_features, num_features * 2, kernel_size=3, stride=2, padding=1
|
||||||
|
),
|
||||||
|
ConvBlock(
|
||||||
|
num_features * 2,
|
||||||
|
num_features * 4,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.res_blocks = nn.Sequential(
|
self.res_blocks = nn.Sequential(
|
||||||
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
|
*[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
|
||||||
)
|
)
|
||||||
self.up_blocks = nn.ModuleList(
|
self.up_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
ConvBlock(
|
||||||
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
num_features * 4,
|
||||||
|
num_features * 2,
|
||||||
|
down=False,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
output_padding=1,
|
||||||
|
),
|
||||||
|
ConvBlock(
|
||||||
|
num_features * 2,
|
||||||
|
num_features * 1,
|
||||||
|
down=False,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
output_padding=1,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
|
self.last = nn.Conv2d(
|
||||||
|
num_features * 1,
|
||||||
|
img_channels,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=1,
|
||||||
|
padding=3,
|
||||||
|
padding_mode="reflect",
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.initial(x)
|
x = self.initial(x)
|
||||||
@@ -61,6 +110,7 @@ class Generator(nn.Module):
|
|||||||
x = layer(x)
|
x = layer(x)
|
||||||
return torch.tanh(self.last(x))
|
return torch.tanh(self.last(x))
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
img_channels = 3
|
img_channels = 3
|
||||||
img_size = 256
|
img_size = 256
|
||||||
@@ -68,5 +118,6 @@ def test():
|
|||||||
gen = Generator(img_channels, 9)
|
gen = Generator(img_channels, 9)
|
||||||
print(gen(x).shape)
|
print(gen(x).shape)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test()
|
test()
|
||||||
|
|||||||
@@ -1,3 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Training for CycleGAN
|
||||||
|
|
||||||
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||||
|
* 2020-11-05: Initial coding
|
||||||
|
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from dataset import HorseZebraDataset
|
from dataset import HorseZebraDataset
|
||||||
import sys
|
import sys
|
||||||
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
|
|||||||
from discriminator_model import Discriminator
|
from discriminator_model import Discriminator
|
||||||
from generator_model import Generator
|
from generator_model import Generator
|
||||||
|
|
||||||
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
|
|
||||||
|
def train_fn(
|
||||||
|
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
|
||||||
|
):
|
||||||
H_reals = 0
|
H_reals = 0
|
||||||
H_fakes = 0
|
H_fakes = 0
|
||||||
loop = tqdm(loader, leave=True)
|
loop = tqdm(loader, leave=True)
|
||||||
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
|||||||
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
|
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
|
||||||
|
|
||||||
# put it togethor
|
# put it togethor
|
||||||
D_loss = (D_H_loss + D_Z_loss)/2
|
D_loss = (D_H_loss + D_Z_loss) / 2
|
||||||
|
|
||||||
opt_disc.zero_grad()
|
opt_disc.zero_grad()
|
||||||
d_scaler.scale(D_loss).backward()
|
d_scaler.scale(D_loss).backward()
|
||||||
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
|||||||
g_scaler.update()
|
g_scaler.update()
|
||||||
|
|
||||||
if idx % 200 == 0:
|
if idx % 200 == 0:
|
||||||
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
|
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
|
||||||
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
|
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
|
||||||
|
|
||||||
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
|
|
||||||
|
|
||||||
|
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -111,23 +121,39 @@ def main():
|
|||||||
|
|
||||||
if config.LOAD_MODEL:
|
if config.LOAD_MODEL:
|
||||||
load_checkpoint(
|
load_checkpoint(
|
||||||
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
|
config.CHECKPOINT_GEN_H,
|
||||||
|
gen_H,
|
||||||
|
opt_gen,
|
||||||
|
config.LEARNING_RATE,
|
||||||
)
|
)
|
||||||
load_checkpoint(
|
load_checkpoint(
|
||||||
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
|
config.CHECKPOINT_GEN_Z,
|
||||||
|
gen_Z,
|
||||||
|
opt_gen,
|
||||||
|
config.LEARNING_RATE,
|
||||||
)
|
)
|
||||||
load_checkpoint(
|
load_checkpoint(
|
||||||
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
|
config.CHECKPOINT_CRITIC_H,
|
||||||
|
disc_H,
|
||||||
|
opt_disc,
|
||||||
|
config.LEARNING_RATE,
|
||||||
)
|
)
|
||||||
load_checkpoint(
|
load_checkpoint(
|
||||||
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
|
config.CHECKPOINT_CRITIC_Z,
|
||||||
|
disc_Z,
|
||||||
|
opt_disc,
|
||||||
|
config.LEARNING_RATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = HorseZebraDataset(
|
dataset = HorseZebraDataset(
|
||||||
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
|
root_horse=config.TRAIN_DIR + "/horses",
|
||||||
|
root_zebra=config.TRAIN_DIR + "/zebras",
|
||||||
|
transform=config.transforms,
|
||||||
)
|
)
|
||||||
val_dataset = HorseZebraDataset(
|
val_dataset = HorseZebraDataset(
|
||||||
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
|
root_horse="cyclegan_test/horse1",
|
||||||
|
root_zebra="cyclegan_test/zebra1",
|
||||||
|
transform=config.transforms,
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
@@ -140,13 +166,25 @@ def main():
|
|||||||
batch_size=config.BATCH_SIZE,
|
batch_size=config.BATCH_SIZE,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=config.NUM_WORKERS,
|
num_workers=config.NUM_WORKERS,
|
||||||
pin_memory=True
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
g_scaler = torch.cuda.amp.GradScaler()
|
g_scaler = torch.cuda.amp.GradScaler()
|
||||||
d_scaler = torch.cuda.amp.GradScaler()
|
d_scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
for epoch in range(config.NUM_EPOCHS):
|
for epoch in range(config.NUM_EPOCHS):
|
||||||
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
|
train_fn(
|
||||||
|
disc_H,
|
||||||
|
disc_Z,
|
||||||
|
gen_Z,
|
||||||
|
gen_H,
|
||||||
|
loader,
|
||||||
|
opt_disc,
|
||||||
|
opt_gen,
|
||||||
|
L1,
|
||||||
|
mse,
|
||||||
|
d_scaler,
|
||||||
|
g_scaler,
|
||||||
|
)
|
||||||
|
|
||||||
if config.SAVE_MODEL:
|
if config.SAVE_MODEL:
|
||||||
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
|
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
|
||||||
@@ -154,5 +192,6 @@ def main():
|
|||||||
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
|
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
|
||||||
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
|
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""
|
|
||||||
Example code of how to code GANs and more specifically DCGAN,
|
|
||||||
for more information about DCGANs read: https://arxiv.org/abs/1511.06434
|
|
||||||
|
|
||||||
We then train the DCGAN on the MNIST dataset (toy dataset of handwritten digits)
|
|
||||||
and then generate our own. You can apply this more generally on really any dataset
|
|
||||||
but MNIST is simple enough to get the overall idea.
|
|
||||||
|
|
||||||
Video explanation: https://youtu.be/5RYETbFFQ7s
|
|
||||||
Got any questions leave a comment on youtube :)
|
|
||||||
|
|
||||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
|
||||||
* 2020-04-20 Initial coding
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Imports
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
|
|
||||||
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
|
|
||||||
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
|
|
||||||
import torchvision.transforms as transforms # Transformations we can perform on our dataset
|
|
||||||
from torch.utils.data import (
|
|
||||||
DataLoader,
|
|
||||||
) # Gives easier dataset managment and creates mini batches
|
|
||||||
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
|
|
||||||
from model_utils import (
|
|
||||||
Discriminator,
|
|
||||||
Generator,
|
|
||||||
) # Import our models we've defined (from DCGAN paper)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
lr = 0.0005
|
|
||||||
batch_size = 64
|
|
||||||
image_size = 64
|
|
||||||
channels_img = 1
|
|
||||||
channels_noise = 256
|
|
||||||
num_epochs = 10
|
|
||||||
|
|
||||||
# For how many channels Generator and Discriminator should use
|
|
||||||
features_d = 16
|
|
||||||
features_g = 16
|
|
||||||
|
|
||||||
my_transforms = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.Resize(image_size),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.5,), (0.5,)),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = datasets.MNIST(
|
|
||||||
root="dataset/", train=True, transform=my_transforms, download=True
|
|
||||||
)
|
|
||||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# Create discriminator and generator
|
|
||||||
netD = Discriminator(channels_img, features_d).to(device)
|
|
||||||
netG = Generator(channels_noise, channels_img, features_g).to(device)
|
|
||||||
|
|
||||||
# Setup Optimizer for G and D
|
|
||||||
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
|
|
||||||
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
|
|
||||||
|
|
||||||
netG.train()
|
|
||||||
netD.train()
|
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
|
||||||
|
|
||||||
real_label = 1
|
|
||||||
fake_label = 0
|
|
||||||
|
|
||||||
fixed_noise = torch.randn(64, channels_noise, 1, 1).to(device)
|
|
||||||
writer_real = SummaryWriter(f"runs/GAN_MNIST/test_real")
|
|
||||||
writer_fake = SummaryWriter(f"runs/GAN_MNIST/test_fake")
|
|
||||||
step = 0
|
|
||||||
|
|
||||||
print("Starting Training...")
|
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
for batch_idx, (data, targets) in enumerate(dataloader):
|
|
||||||
data = data.to(device)
|
|
||||||
batch_size = data.shape[0]
|
|
||||||
|
|
||||||
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
|
|
||||||
netD.zero_grad()
|
|
||||||
label = (torch.ones(batch_size) * 0.9).to(device)
|
|
||||||
output = netD(data).reshape(-1)
|
|
||||||
lossD_real = criterion(output, label)
|
|
||||||
D_x = output.mean().item()
|
|
||||||
|
|
||||||
noise = torch.randn(batch_size, channels_noise, 1, 1).to(device)
|
|
||||||
fake = netG(noise)
|
|
||||||
label = (torch.ones(batch_size) * 0.1).to(device)
|
|
||||||
|
|
||||||
output = netD(fake.detach()).reshape(-1)
|
|
||||||
lossD_fake = criterion(output, label)
|
|
||||||
|
|
||||||
lossD = lossD_real + lossD_fake
|
|
||||||
lossD.backward()
|
|
||||||
optimizerD.step()
|
|
||||||
|
|
||||||
### Train Generator: max log(D(G(z)))
|
|
||||||
netG.zero_grad()
|
|
||||||
label = torch.ones(batch_size).to(device)
|
|
||||||
output = netD(fake).reshape(-1)
|
|
||||||
lossG = criterion(output, label)
|
|
||||||
lossG.backward()
|
|
||||||
optimizerG.step()
|
|
||||||
|
|
||||||
# Print losses ocassionally and print to tensorboard
|
|
||||||
if batch_idx % 100 == 0:
|
|
||||||
step += 1
|
|
||||||
print(
|
|
||||||
f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} \
|
|
||||||
Loss D: {lossD:.4f}, loss G: {lossG:.4f} D(x): {D_x:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
fake = netG(fixed_noise)
|
|
||||||
img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
|
|
||||||
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
|
|
||||||
writer_real.add_image(
|
|
||||||
"Mnist Real Images", img_grid_real, global_step=step
|
|
||||||
)
|
|
||||||
writer_fake.add_image(
|
|
||||||
"Mnist Fake Images", img_grid_fake, global_step=step
|
|
||||||
)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
### Generative Adversarial Network
|
|
||||||
|
|
||||||
DCGAN_mnist.py: main file that trains the network
model_utils.py: Generator and Discriminator implementations
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""
|
|
||||||
Discriminator and Generator implementation from DCGAN paper
|
|
||||||
that we import in the main (DCGAN_mnist.py) file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an N x channels_img x 64 x 64 batch of
    images to a per-image real/fake probability of shape N x 1 x 1 x 1."""

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # Assemble the layer list first, then wrap it in ONE flat
        # nn.Sequential so parameter names (net.0.weight, net.1..., ...)
        # keep the conventional flat DCGAN layout.
        layers = [
            # N x channels_img x 64 x 64 -> N x features_d x 32 x 32
            # (no BatchNorm on the first layer, per the DCGAN paper)
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three strided conv blocks; each halves the spatial size and
        # doubles the channel count: 32 -> 16 -> 8 -> 4.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(
                    features_d * mult,
                    features_d * mult * 2,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
                nn.BatchNorm2d(features_d * mult * 2),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            # N x features_d*8 x 4 x 4 -> N x 1 x 1 x 1
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            # Squash to a probability for use with BCE loss.
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the discriminator; output shape is N x 1 x 1 x 1."""
        return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
    """DCGAN generator: maps a latent batch of shape
    N x channels_noise x 1 x 1 to images of shape N x channels_img x 64 x 64
    with values in [-1, 1] (Tanh output)."""

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # Build the layers as a list, then wrap in ONE flat nn.Sequential so
        # parameter names (net.0.weight, ...) keep the usual flat layout.
        layers = [
            # N x channels_noise x 1 x 1 -> N x features_g*16 x 4 x 4
            nn.ConvTranspose2d(
                channels_noise, features_g * 16, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),
        ]
        # Upsampling blocks: spatial size 4 -> 8 -> 16 -> 32 while the
        # channel count is halved at each step (16x -> 8x -> 4x -> 2x).
        for mult in (16, 8, 4):
            layers += [
                nn.ConvTranspose2d(
                    features_g * mult,
                    features_g * mult // 2,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
                nn.BatchNorm2d(features_g * mult // 2),
                nn.ReLU(),
            ]
        layers += [
            # 32 -> 64 and project down to the image channel count.
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Tanh maps pixels to [-1, 1], matching Normalize((0.5,), (0.5,)).
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator; output shape is N x channels_img x 64 x 64."""
        return self.net(x)
|
|
||||||
Reference in New Issue
Block a user