checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,3 +1,12 @@
"""
Simple GAN using fully connected layers
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
import torch.optim as optim
@@ -48,7 +57,10 @@ disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
[
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
@@ -104,4 +116,4 @@ for epoch in range(num_epochs):
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
step += 1
step += 1

View File

@@ -1,5 +1,9 @@
"""
Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
# input: N x channels_img x 64 x 64
nn.Conv2d(
channels_img, features_d, kernel_size=4, stride=2, padding=1
),
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
# _block(in_channels, out_channels, kernel_size, stride, padding)
self._block(features_d, features_d * 2, 4, 2, 1),
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
padding,
bias=False,
),
#nn.BatchNorm2d(out_channels),
# nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2),
)
@@ -68,7 +70,7 @@ class Generator(nn.Module):
padding,
bias=False,
),
#nn.BatchNorm2d(out_channels),
# nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
@@ -82,6 +84,7 @@ def initialize_weights(model):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)
def test():
N, in_channels, H, W = 8, 3, 64, 64
noise_dim = 100
@@ -91,6 +94,8 @@ def test():
gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
# test()
if __name__ == "__main__":
test()

View File

@@ -1,6 +1,10 @@
"""
Training of DCGAN network on MNIST dataset with Discriminator
and Generator imported from models.py
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
)
# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
download=True)
dataset = datasets.MNIST(
root="dataset/", train=True, transform=transforms, download=True
)
# comment mnist above and uncomment below if train on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
@@ -92,14 +97,10 @@ for epoch in range(NUM_EPOCHS):
with torch.no_grad():
fake = gen(fixed_noise)
# take out (up to) 32 examples
img_grid_real = torchvision.utils.make_grid(
real[:32], normalize=True
)
img_grid_fake = torchvision.utils.make_grid(
fake[:32], normalize=True
)
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1
step += 1

View File

@@ -2,6 +2,10 @@
Discriminator and Generator implementation from DCGAN paper,
with removed Sigmoid() as output from Discriminator (and therefore
it should be called critic)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -93,6 +97,7 @@ def test():
gen = Generator(noise_dim, in_channels, 8)
z = torch.randn((N, noise_dim, 1, 1))
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
print("Success, tests passed!")
# test()
if __name__ == "__main__":
test()

View File

@@ -1,5 +1,9 @@
"""
Training of DCGAN network with WGAN loss
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -9,6 +13,7 @@ import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
@@ -61,7 +66,7 @@ critic.train()
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (data, _) in enumerate(loader):
for batch_idx, (data, _) in enumerate(tqdm(loader)):
data = data.to(device)
cur_batch_size = data.shape[0]
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
step += 1
gen.train()
critic.train()
critic.train()

View File

@@ -1,5 +1,9 @@
"""
Discriminator and Generator implementation from DCGAN paper
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -24,7 +28,12 @@ class Discriminator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
),
nn.InstanceNorm2d(out_channels, affine=True),
nn.LeakyReLU(0.2),
@@ -53,7 +62,12 @@ class Generator(nn.Module):
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),

View File

@@ -1,5 +1,9 @@
"""
Training of WGAN-GP
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -10,6 +14,7 @@ import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import gradient_penalty, save_checkpoint, load_checkpoint
from model import Discriminator, Generator, initialize_weights
@@ -31,13 +36,14 @@ transforms = transforms.Compose(
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
),
]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# comment mnist above and uncomment below for training on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
@@ -66,7 +72,7 @@ critic.train()
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (real, _) in enumerate(loader):
for batch_idx, (real, _) in enumerate(tqdm(loader)):
real = real.to(device)
cur_batch_size = real.shape[0]
@@ -108,4 +114,4 @@ for epoch in range(NUM_EPOCHS):
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
step += 1
step += 1

View File

@@ -11,7 +11,7 @@ LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 10
LOAD_MODEL = True
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
@@ -24,6 +24,6 @@ transforms = A.Compose(
A.HorizontalFlip(p=0.5),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
ToTensorV2(),
],
],
additional_targets={"image0": "image"},
)
)

View File

@@ -1,11 +1,28 @@
"""
Discriminator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
nn.Conv2d(
in_channels,
out_channels,
4,
stride,
1,
bias=True,
padding_mode="reflect",
),
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2, inplace=True),
)
@@ -32,15 +49,27 @@ class Discriminator(nn.Module):
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
layers.append(
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
)
in_channels = feature
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
layers.append(
nn.Conv2d(
in_channels,
1,
kernel_size=4,
stride=1,
padding=1,
padding_mode="reflect",
)
)
self.model = nn.Sequential(*layers)
def forward(self, x):
x = self.initial(x)
return torch.sigmoid(self.model(x))
def test():
x = torch.randn((5, 3, 256, 256))
model = Discriminator(in_channels=3)
@@ -50,4 +79,3 @@ def test():
if __name__ == "__main__":
test()

View File

@@ -1,6 +1,15 @@
"""
Generator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
super().__init__()
@@ -9,12 +18,13 @@ class ConvBlock(nn.Module):
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels),
nn.ReLU(inplace=True) if use_act else nn.Identity()
nn.ReLU(inplace=True) if use_act else nn.Identity(),
)
def forward(self, x):
return self.conv(x)
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
@@ -26,31 +36,70 @@ class ResidualBlock(nn.Module):
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
def __init__(self, img_channels, num_features = 64, num_residuals=9):
def __init__(self, img_channels, num_features=64, num_residuals=9):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
nn.Conv2d(
img_channels,
num_features,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
),
nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True),
)
self.down_blocks = nn.ModuleList(
[
ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
ConvBlock(
num_features, num_features * 2, kernel_size=3, stride=2, padding=1
),
ConvBlock(
num_features * 2,
num_features * 4,
kernel_size=3,
stride=2,
padding=1,
),
]
)
self.res_blocks = nn.Sequential(
*[ResidualBlock(num_features*4) for _ in range(num_residuals)]
*[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
)
self.up_blocks = nn.ModuleList(
[
ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
ConvBlock(
num_features * 4,
num_features * 2,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
ConvBlock(
num_features * 2,
num_features * 1,
down=False,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
]
)
self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
self.last = nn.Conv2d(
num_features * 1,
img_channels,
kernel_size=7,
stride=1,
padding=3,
padding_mode="reflect",
)
def forward(self, x):
x = self.initial(x)
@@ -61,6 +110,7 @@ class Generator(nn.Module):
x = layer(x)
return torch.tanh(self.last(x))
def test():
img_channels = 3
img_size = 256
@@ -68,5 +118,6 @@ def test():
gen = Generator(img_channels, 9)
print(gen(x).shape)
if __name__ == "__main__":
test()

View File

@@ -1,3 +1,11 @@
"""
Training for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
from dataset import HorseZebraDataset
import sys
@@ -11,7 +19,10 @@ from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
def train_fn(
disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
H_reals = 0
H_fakes = 0
loop = tqdm(loader, leave=True)
@@ -39,7 +50,7 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
D_Z_loss = D_Z_real_loss + D_Z_fake_loss
# put it together
D_loss = (D_H_loss + D_Z_loss)/2
D_loss = (D_H_loss + D_Z_loss) / 2
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
@@ -82,11 +93,10 @@ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d
g_scaler.update()
if idx % 200 == 0:
save_image(fake_horse*0.5+0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
def main():
@@ -111,23 +121,39 @@ def main():
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_H,
gen_H,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
config.CHECKPOINT_GEN_Z,
gen_Z,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_H,
disc_H,
opt_disc,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
config.CHECKPOINT_CRITIC_Z,
disc_Z,
opt_disc,
config.LEARNING_RATE,
)
dataset = HorseZebraDataset(
root_horse=config.TRAIN_DIR+"/horses", root_zebra=config.TRAIN_DIR+"/zebras", transform=config.transforms
root_horse=config.TRAIN_DIR + "/horses",
root_zebra=config.TRAIN_DIR + "/zebras",
transform=config.transforms,
)
val_dataset = HorseZebraDataset(
root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=config.transforms
root_horse="cyclegan_test/horse1",
root_zebra="cyclegan_test/zebra1",
transform=config.transforms,
)
val_loader = DataLoader(
val_dataset,
@@ -140,13 +166,25 @@ def main():
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True
pin_memory=True,
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(config.NUM_EPOCHS):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
train_fn(
disc_H,
disc_Z,
gen_Z,
gen_H,
loader,
opt_disc,
opt_gen,
L1,
mse,
d_scaler,
g_scaler,
)
if config.SAVE_MODEL:
save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
@@ -154,5 +192,6 @@ def main():
save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)
if __name__ == "__main__":
main()
main()

View File

@@ -1,131 +0,0 @@
"""
Example code of how to code GANs and more specifically DCGAN,
for more information about DCGANs read: https://arxiv.org/abs/1511.06434
We then train the DCGAN on the MNIST dataset (toy dataset of handwritten digits)
and then generate our own. You can apply this more generally on really any dataset
but MNIST is simple enough to get the overall idea.
Video explanation: https://youtu.be/5RYETbFFQ7s
Got any questions leave a comment on youtube :)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-20 Initial coding
"""
# Imports
import torch
import torchvision
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
from torch.utils.data import (
DataLoader,
) # Gives easier dataset management and creates mini batches
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
from model_utils import (
Discriminator,
Generator,
) # Import our models we've defined (from DCGAN paper)
# Hyperparameters
lr = 0.0005
batch_size = 64
image_size = 64
channels_img = 1
channels_noise = 256
num_epochs = 10
# For how many channels Generator and Discriminator should use
features_d = 16
features_g = 16
my_transforms = transforms.Compose(
[
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]
)
dataset = datasets.MNIST(
root="dataset/", train=True, transform=my_transforms, download=True
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create discriminator and generator
netD = Discriminator(channels_img, features_d).to(device)
netG = Generator(channels_noise, channels_img, features_g).to(device)
# Setup Optimizer for G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
netG.train()
netD.train()
criterion = nn.BCELoss()
real_label = 1
fake_label = 0
fixed_noise = torch.randn(64, channels_noise, 1, 1).to(device)
writer_real = SummaryWriter(f"runs/GAN_MNIST/test_real")
writer_fake = SummaryWriter(f"runs/GAN_MNIST/test_fake")
step = 0
print("Starting Training...")
for epoch in range(num_epochs):
for batch_idx, (data, targets) in enumerate(dataloader):
data = data.to(device)
batch_size = data.shape[0]
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
netD.zero_grad()
label = (torch.ones(batch_size) * 0.9).to(device)
output = netD(data).reshape(-1)
lossD_real = criterion(output, label)
D_x = output.mean().item()
noise = torch.randn(batch_size, channels_noise, 1, 1).to(device)
fake = netG(noise)
label = (torch.ones(batch_size) * 0.1).to(device)
output = netD(fake.detach()).reshape(-1)
lossD_fake = criterion(output, label)
lossD = lossD_real + lossD_fake
lossD.backward()
optimizerD.step()
### Train Generator: max log(D(G(z)))
netG.zero_grad()
label = torch.ones(batch_size).to(device)
output = netD(fake).reshape(-1)
lossG = criterion(output, label)
lossG.backward()
optimizerG.step()
# Print losses occasionally and print to tensorboard
if batch_idx % 100 == 0:
step += 1
print(
f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} \
Loss D: {lossD:.4f}, loss G: {lossG:.4f} D(x): {D_x:.4f}"
)
with torch.no_grad():
fake = netG(fixed_noise)
img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
writer_fake.add_image(
"Mnist Fake Images", img_grid_fake, global_step=step
)

View File

@@ -1,4 +0,0 @@
### Generative Adversarial Network
DCGAN_mnist.py: main file and training network
model_utils.py: Generator and discriminator implementation

View File

@@ -1,76 +0,0 @@
"""
Discriminator and Generator implementation from DCGAN paper
that we import in the main (DCGAN_mnist.py) file.
"""
import torch
import torch.nn as nn
class Discriminator(nn.Module):
def __init__(self, channels_img, features_d):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
# N x channels_img x 64 x 64
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
# N x features_d x 32 x 32
nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(features_d * 2),
nn.LeakyReLU(0.2),
nn.Conv2d(
features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(features_d * 4),
nn.LeakyReLU(0.2),
nn.Conv2d(
features_d * 4, features_d * 8, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(features_d * 8),
nn.LeakyReLU(0.2),
# N x features_d*8 x 4 x 4
nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
# N x 1 x 1 x 1
nn.Sigmoid(),
)
def forward(self, x):
return self.net(x)
class Generator(nn.Module):
def __init__(self, channels_noise, channels_img, features_g):
super(Generator, self).__init__()
self.net = nn.Sequential(
# N x channels_noise x 1 x 1
nn.ConvTranspose2d(
channels_noise, features_g * 16, kernel_size=4, stride=1, padding=0
),
nn.BatchNorm2d(features_g * 16),
nn.ReLU(),
# N x features_g*16 x 4 x 4
nn.ConvTranspose2d(
features_g * 16, features_g * 8, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(features_g * 8),
nn.ReLU(),
nn.ConvTranspose2d(
features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(features_g * 4),
nn.ReLU(),
nn.ConvTranspose2d(
features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(features_g * 2),
nn.ReLU(),
nn.ConvTranspose2d(
features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
),
# N x channels_img x 64 x 64
nn.Tanh(),
)
def forward(self, x):
return self.net(x)