mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
checked GAN code
This commit is contained in:
@@ -1,3 +1,12 @@
|
||||
"""
|
||||
Simple GAN using fully connected layers
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@@ -48,7 +57,10 @@ disc = Discriminator(image_dim).to(device)
|
||||
gen = Generator(z_dim, image_dim).to(device)
|
||||
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
|
||||
transforms = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
]
|
||||
)
|
||||
|
||||
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
||||
@@ -104,4 +116,4 @@ for epoch in range(num_epochs):
|
||||
writer_real.add_image(
|
||||
"Mnist Real Images", img_grid_real, global_step=step
|
||||
)
|
||||
step += 1
|
||||
step += 1
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
|
||||
super(Discriminator, self).__init__()
|
||||
self.disc = nn.Sequential(
|
||||
# input: N x channels_img x 64 x 64
|
||||
nn.Conv2d(
|
||||
channels_img, features_d, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2),
|
||||
# _block(in_channels, out_channels, kernel_size, stride, padding)
|
||||
self._block(features_d, features_d * 2, 4, 2, 1),
|
||||
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
#nn.BatchNorm2d(out_channels),
|
||||
# nn.BatchNorm2d(out_channels),
|
||||
nn.LeakyReLU(0.2),
|
||||
)
|
||||
|
||||
@@ -68,7 +70,7 @@ class Generator(nn.Module):
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
#nn.BatchNorm2d(out_channels),
|
||||
# nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
@@ -82,6 +84,7 @@ def initialize_weights(model):
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
|
||||
|
||||
def test():
|
||||
N, in_channels, H, W = 8, 3, 64, 64
|
||||
noise_dim = 100
|
||||
@@ -91,6 +94,8 @@ def test():
|
||||
gen = Generator(noise_dim, in_channels, 8)
|
||||
z = torch.randn((N, noise_dim, 1, 1))
|
||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||
print("Success, tests passed!")
|
||||
|
||||
|
||||
# test()
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
Training of DCGAN network on MNIST dataset with Discriminator
|
||||
and Generator imported from models.py
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
|
||||
)
|
||||
|
||||
# If you train on MNIST, remember to set channels_img to 1
|
||||
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
|
||||
download=True)
|
||||
dataset = datasets.MNIST(
|
||||
root="dataset/", train=True, transform=transforms, download=True
|
||||
)
|
||||
|
||||
# comment mnist above and uncomment below if train on CelebA
|
||||
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
|
||||
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
|
||||
@@ -92,14 +97,10 @@ for epoch in range(NUM_EPOCHS):
|
||||
with torch.no_grad():
|
||||
fake = gen(fixed_noise)
|
||||
# take out (up to) 32 examples
|
||||
img_grid_real = torchvision.utils.make_grid(
|
||||
real[:32], normalize=True
|
||||
)
|
||||
img_grid_fake = torchvision.utils.make_grid(
|
||||
fake[:32], normalize=True
|
||||
)
|
||||
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
|
||||
|
||||
writer_real.add_image("Real", img_grid_real, global_step=step)
|
||||
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
|
||||
|
||||
step += 1
|
||||
step += 1
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
Discriminator and Generator implementation from DCGAN paper,
|
||||
with removed Sigmoid() as output from Discriminator (and therefor
|
||||
it should be called critic)
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -93,6 +97,7 @@ def test():
|
||||
gen = Generator(noise_dim, in_channels, 8)
|
||||
z = torch.randn((N, noise_dim, 1, 1))
|
||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||
print("Success, tests passed!")
|
||||
|
||||
|
||||
# test()
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Training of DCGAN network with WGAN loss
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -9,6 +13,7 @@ import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
@@ -61,7 +66,7 @@ critic.train()
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
# Target labels not needed! <3 unsupervised
|
||||
for batch_idx, (data, _) in enumerate(loader):
|
||||
for batch_idx, (data, _) in enumerate(tqdm(loader)):
|
||||
data = data.to(device)
|
||||
cur_batch_size = data.shape[0]
|
||||
|
||||
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
|
||||
|
||||
step += 1
|
||||
gen.train()
|
||||
critic.train()
|
||||
critic.train()
|
||||
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -24,7 +28,12 @@ class Discriminator(nn.Module):
|
||||
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, bias=False,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
nn.InstanceNorm2d(out_channels, affine=True),
|
||||
nn.LeakyReLU(0.2),
|
||||
@@ -53,7 +62,12 @@ class Generator(nn.Module):
|
||||
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, bias=False,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(),
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Training of WGAN-GP
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -10,6 +14,7 @@ import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from utils import gradient_penalty, save_checkpoint, load_checkpoint
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
@@ -31,13 +36,14 @@ transforms = transforms.Compose(
|
||||
transforms.Resize(IMAGE_SIZE),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
|
||||
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
|
||||
# comment mnist above and uncomment below for training on CelebA
|
||||
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
@@ -66,7 +72,7 @@ critic.train()
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
# Target labels not needed! <3 unsupervised
|
||||
for batch_idx, (real, _) in enumerate(loader):
|
||||
for batch_idx, (real, _) in enumerate(tqdm(loader)):
|
||||
real = real.to(device)
|
||||
cur_batch_size = real.shape[0]
|
||||
|
||||
@@ -108,4 +114,4 @@ for epoch in range(NUM_EPOCHS):
|
||||
writer_real.add_image("Real", img_grid_real, global_step=step)
|
||||
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
|
||||
|
||||
step += 1
|
||||
step += 1
|
||||
|
||||
@@ -11,7 +11,7 @@ LAMBDA_IDENTITY = 0.0
|
||||
LAMBDA_CYCLE = 10
|
||||
NUM_WORKERS = 4
|
||||
NUM_EPOCHS = 10
|
||||
LOAD_MODEL = True
|
||||
LOAD_MODEL = False
|
||||
SAVE_MODEL = True
|
||||
CHECKPOINT_GEN_H = "genh.pth.tar"
|
||||
CHECKPOINT_GEN_Z = "genz.pth.tar"
|
||||
@@ -24,6 +24,6 @@ transforms = A.Compose(
|
||||
A.HorizontalFlip(p=0.5),
|
||||
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
],
|
||||
additional_targets={"image0": "image"},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,28 @@
|
||||
"""
|
||||
Discriminator model for CycleGAN
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-05: Initial coding
|
||||
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
4,
|
||||
stride,
|
||||
1,
|
||||
bias=True,
|
||||
padding_mode="reflect",
|
||||
),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
@@ -32,15 +49,27 @@ class Discriminator(nn.Module):
|
||||
layers = []
|
||||
in_channels = features[0]
|
||||
for feature in features[1:]:
|
||||
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
|
||||
layers.append(
|
||||
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
|
||||
)
|
||||
in_channels = feature
|
||||
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
|
||||
layers.append(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
1,
|
||||
kernel_size=4,
|
||||
stride=1,
|
||||
padding=1,
|
||||
padding_mode="reflect",
|
||||
)
|
||||
)
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial(x)
|
||||
return torch.sigmoid(self.model(x))
|
||||
|
||||
|
||||
def test():
|
||||
x = torch.randn((5, 3, 256, 256))
|
||||
model = Discriminator(in_channels=3)
|
||||
@@ -50,4 +79,3 @@ def test():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
"""
|
||||
Generator model for CycleGAN
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-05: Initial coding
|
||||
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
|
||||
super().__init__()
|
||||
@@ -9,12 +18,13 @@ class ConvBlock(nn.Module):
|
||||
if down
|
||||
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True) if use_act else nn.Identity()
|
||||
nn.ReLU(inplace=True) if use_act else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
@@ -26,31 +36,70 @@ class ResidualBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, img_channels, num_features = 64, num_residuals=9):
|
||||
def __init__(self, img_channels, num_features=64, num_residuals=9):
|
||||
super().__init__()
|
||||
self.initial = nn.Sequential(
|
||||
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
|
||||
nn.Conv2d(
|
||||
img_channels,
|
||||
num_features,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
padding_mode="reflect",
|
||||
),
|
||||
nn.InstanceNorm2d(num_features),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.down_blocks = nn.ModuleList(
|
||||
[
|
||||
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
|
||||
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
|
||||
ConvBlock(
|
||||
num_features, num_features * 2, kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
ConvBlock(
|
||||
num_features * 2,
|
||||
num_features * 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.res_blocks = nn.Sequential(
|
||||
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
|
||||
*[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
|
||||
)
|
||||
self.up_blocks = nn.ModuleList(
|
||||
[
|
||||
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
ConvBlock(
|
||||
num_features * 4,
|
||||
num_features * 2,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
),
|
||||
ConvBlock(
|
||||
num_features * 2,
|
||||
num_features * 1,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
|
||||
self.last = nn.Conv2d(
|
||||
num_features * 1,
|
||||
img_channels,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
padding_mode="reflect",
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial(x)
|
||||
@@ -61,6 +110,7 @@ class Generator(nn.Module):
|
||||
x = layer(x)
|
||||
return torch.tanh(self.last(x))
|
||||
|
||||
|
||||
def test():
|
||||
img_channels = 3
|
||||
img_size = 256
|
||||
@@ -68,5 +118,6 @@ def test():
|
||||
gen = Generator(img_channels, 9)
|
||||
print(gen(x).shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
"""
|
||||
Training for CycleGAN
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-05: Initial coding
|
||||
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
from dataset import HorseZebraDataset
|
||||
import sys
|
||||
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
|
||||
from discriminator_model import Discriminator
|
||||
from generator_model import Generator
|
||||
|
||||
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
|
||||
|
||||
def train_fn(
|
||||
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
|
||||
):
|
||||
H_reals = 0
|
||||
H_fakes = 0
|
||||
loop = tqdm(loader, leave=True)
|
||||
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
||||
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
|
||||
|
||||
# put it togethor
|
||||
D_loss = (D_H_loss + D_Z_loss)/2
|
||||
D_loss = (D_H_loss + D_Z_loss) / 2
|
||||
|
||||
opt_disc.zero_grad()
|
||||
d_scaler.scale(D_loss).backward()
|
||||
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
|
||||
g_scaler.update()
|
||||
|
||||
if idx % 200 == 0:
|
||||
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
|
||||
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
|
||||
|
||||
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
|
||||
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
|
||||
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
|
||||
|
||||
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
|
||||
|
||||
|
||||
def main():
|
||||
@@ -111,23 +121,39 @@ def main():
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_GEN_H,
|
||||
gen_H,
|
||||
opt_gen,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_GEN_Z,
|
||||
gen_Z,
|
||||
opt_gen,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_CRITIC_H,
|
||||
disc_H,
|
||||
opt_disc,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
load_checkpoint(
|
||||
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
|
||||
config.CHECKPOINT_CRITIC_Z,
|
||||
disc_Z,
|
||||
opt_disc,
|
||||
config.LEARNING_RATE,
|
||||
)
|
||||
|
||||
dataset = HorseZebraDataset(
|
||||
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
|
||||
root_horse=config.TRAIN_DIR + "/horses",
|
||||
root_zebra=config.TRAIN_DIR + "/zebras",
|
||||
transform=config.transforms,
|
||||
)
|
||||
val_dataset = HorseZebraDataset(
|
||||
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
|
||||
root_horse="cyclegan_test/horse1",
|
||||
root_zebra="cyclegan_test/zebra1",
|
||||
transform=config.transforms,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
@@ -140,13 +166,25 @@ def main():
|
||||
batch_size=config.BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=config.NUM_WORKERS,
|
||||
pin_memory=True
|
||||
pin_memory=True,
|
||||
)
|
||||
g_scaler = torch.cuda.amp.GradScaler()
|
||||
d_scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
for epoch in range(config.NUM_EPOCHS):
|
||||
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
|
||||
train_fn(
|
||||
disc_H,
|
||||
disc_Z,
|
||||
gen_Z,
|
||||
gen_H,
|
||||
loader,
|
||||
opt_disc,
|
||||
opt_gen,
|
||||
L1,
|
||||
mse,
|
||||
d_scaler,
|
||||
g_scaler,
|
||||
)
|
||||
|
||||
if config.SAVE_MODEL:
|
||||
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
|
||||
@@ -154,5 +192,6 @@ def main():
|
||||
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
|
||||
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
"""
|
||||
Example code of how to code GANs and more specifically DCGAN,
|
||||
for more information about DCGANs read: https://arxiv.org/abs/1511.06434
|
||||
|
||||
We then train the DCGAN on the MNIST dataset (toy dataset of handwritten digits)
|
||||
and then generate our own. You can apply this more generally on really any dataset
|
||||
but MNIST is simple enough to get the overall idea.
|
||||
|
||||
Video explanation: https://youtu.be/5RYETbFFQ7s
|
||||
Got any questions leave a comment on youtube :)
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-04-20 Initial coding
|
||||
|
||||
"""
|
||||
|
||||
# Imports
|
||||
import torch
|
||||
import torchvision
|
||||
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
|
||||
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
|
||||
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
|
||||
import torchvision.transforms as transforms # Transformations we can perform on our dataset
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
) # Gives easier dataset managment and creates mini batches
|
||||
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
|
||||
from model_utils import (
|
||||
Discriminator,
|
||||
Generator,
|
||||
) # Import our models we've defined (from DCGAN paper)
|
||||
|
||||
# Hyperparameters
|
||||
lr = 0.0005
|
||||
batch_size = 64
|
||||
image_size = 64
|
||||
channels_img = 1
|
||||
channels_noise = 256
|
||||
num_epochs = 10
|
||||
|
||||
# For how many channels Generator and Discriminator should use
|
||||
features_d = 16
|
||||
features_g = 16
|
||||
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(image_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
]
|
||||
)
|
||||
|
||||
dataset = datasets.MNIST(
|
||||
root="dataset/", train=True, transform=my_transforms, download=True
|
||||
)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Create discriminator and generator
|
||||
netD = Discriminator(channels_img, features_d).to(device)
|
||||
netG = Generator(channels_noise, channels_img, features_g).to(device)
|
||||
|
||||
# Setup Optimizer for G and D
|
||||
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
|
||||
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
|
||||
|
||||
netG.train()
|
||||
netD.train()
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
|
||||
real_label = 1
|
||||
fake_label = 0
|
||||
|
||||
fixed_noise = torch.randn(64, channels_noise, 1, 1).to(device)
|
||||
writer_real = SummaryWriter(f"runs/GAN_MNIST/test_real")
|
||||
writer_fake = SummaryWriter(f"runs/GAN_MNIST/test_fake")
|
||||
step = 0
|
||||
|
||||
print("Starting Training...")
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, (data, targets) in enumerate(dataloader):
|
||||
data = data.to(device)
|
||||
batch_size = data.shape[0]
|
||||
|
||||
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
|
||||
netD.zero_grad()
|
||||
label = (torch.ones(batch_size) * 0.9).to(device)
|
||||
output = netD(data).reshape(-1)
|
||||
lossD_real = criterion(output, label)
|
||||
D_x = output.mean().item()
|
||||
|
||||
noise = torch.randn(batch_size, channels_noise, 1, 1).to(device)
|
||||
fake = netG(noise)
|
||||
label = (torch.ones(batch_size) * 0.1).to(device)
|
||||
|
||||
output = netD(fake.detach()).reshape(-1)
|
||||
lossD_fake = criterion(output, label)
|
||||
|
||||
lossD = lossD_real + lossD_fake
|
||||
lossD.backward()
|
||||
optimizerD.step()
|
||||
|
||||
### Train Generator: max log(D(G(z)))
|
||||
netG.zero_grad()
|
||||
label = torch.ones(batch_size).to(device)
|
||||
output = netD(fake).reshape(-1)
|
||||
lossG = criterion(output, label)
|
||||
lossG.backward()
|
||||
optimizerG.step()
|
||||
|
||||
# Print losses ocassionally and print to tensorboard
|
||||
if batch_idx % 100 == 0:
|
||||
step += 1
|
||||
print(
|
||||
f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} \
|
||||
Loss D: {lossD:.4f}, loss G: {lossG:.4f} D(x): {D_x:.4f}"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
fake = netG(fixed_noise)
|
||||
img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
|
||||
writer_real.add_image(
|
||||
"Mnist Real Images", img_grid_real, global_step=step
|
||||
)
|
||||
writer_fake.add_image(
|
||||
"Mnist Fake Images", img_grid_fake, global_step=step
|
||||
)
|
||||
@@ -1,4 +0,0 @@
|
||||
### Generative Adversarial Network
|
||||
|
||||
DCGAN_mnist.py: main file and training network
|
||||
model_utils.py: Generator and discriminator implementation
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
that we import in the main (DCGAN_mnist.py) file.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, channels_img, features_d):
|
||||
super(Discriminator, self).__init__()
|
||||
self.net = nn.Sequential(
|
||||
# N x channels_img x 64 x 64
|
||||
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2),
|
||||
# N x features_d x 32 x 32
|
||||
nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1),
|
||||
nn.BatchNorm2d(features_d * 2),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Conv2d(
|
||||
features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(features_d * 4),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Conv2d(
|
||||
features_d * 4, features_d * 8, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(features_d * 8),
|
||||
nn.LeakyReLU(0.2),
|
||||
# N x features_d*8 x 4 x 4
|
||||
nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
|
||||
# N x 1 x 1 x 1
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, channels_noise, channels_img, features_g):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
self.net = nn.Sequential(
|
||||
# N x channels_noise x 1 x 1
|
||||
nn.ConvTranspose2d(
|
||||
channels_noise, features_g * 16, kernel_size=4, stride=1, padding=0
|
||||
),
|
||||
nn.BatchNorm2d(features_g * 16),
|
||||
nn.ReLU(),
|
||||
# N x features_g*16 x 4 x 4
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 16, features_g * 8, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(features_g * 8),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(features_g * 4),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(features_g * 2),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
# N x channels_img x 64 x 64
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
Reference in New Issue
Block a user