checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,131 +0,0 @@
"""
Example code of how to code GANs and more specifically DCGAN,
for more information about DCGANs read: https://arxiv.org/abs/1511.06434
We then train the DCGAN on the MNIST dataset (toy dataset of handwritten digits)
and then generate our own. You can apply this more generally on really any dataset
but MNIST is simple enough to get the overall idea.
Video explanation: https://youtu.be/5RYETbFFQ7s
If you have any questions, leave a comment on YouTube :)
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-20 Initial coding
"""
# Imports
import torch
import torchvision
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
from torch.utils.data import (
DataLoader,
) # Gives easier dataset managment and creates mini batches
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
from model_utils import (
Discriminator,
Generator,
) # Import our models we've defined (from DCGAN paper)
# ---------------------------------------------------------------------------
# DCGAN training script for MNIST.
#
# Builds the dataset/dataloader, creates the Discriminator and Generator,
# then alternates one discriminator update and one generator update per
# mini-batch. Every 100 batches it prints the current losses and writes a
# grid of real and generated images to TensorBoard.
# ---------------------------------------------------------------------------

# Hyperparameters
lr = 0.0005
batch_size = 64
image_size = 64
channels_img = 1
channels_noise = 256
num_epochs = 10

# For how many channels Generator and Discriminator should use
features_d = 16
features_g = 16

# Resize to image_size, convert to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
my_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(
    root="dataset/", train=True, transform=my_transforms, download=True
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create discriminator and generator
netD = Discriminator(channels_img, features_d).to(device)
netG = Generator(channels_noise, channels_img, features_g).to(device)

# Setup Optimizer for G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

netG.train()
netD.train()

criterion = nn.BCELoss()

# Nominal hard targets. Kept for backward compatibility; the discriminator
# below is trained against smoothed targets (0.9 / 0.1) instead.
real_label = 1
fake_label = 0

# Fixed latent batch so the logged samples are comparable across steps.
fixed_noise = torch.randn(64, channels_noise, 1, 1, device=device)
writer_real = SummaryWriter("runs/GAN_MNIST/test_real")
writer_fake = SummaryWriter("runs/GAN_MNIST/test_fake")
step = 0

print("Starting Training...")

try:
    for epoch in range(num_epochs):
        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.to(device)
            # The last batch of an epoch may be smaller than `batch_size`, so
            # size labels/noise from the data itself. A separate name is used
            # so the `batch_size` hyperparameter is not overwritten.
            cur_batch_size = data.shape[0]

            ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            # Smoothed "real" target (0.9 rather than 1.0).
            label = torch.full((cur_batch_size,), 0.9, device=device)
            output = netD(data).reshape(-1)
            lossD_real = criterion(output, label)
            D_x = output.mean().item()

            noise = torch.randn(cur_batch_size, channels_noise, 1, 1, device=device)
            fake = netG(noise)
            # Smoothed "fake" target (0.1 rather than 0.0).
            label = torch.full((cur_batch_size,), 0.1, device=device)
            # detach() keeps the discriminator loss from backpropagating
            # into the generator.
            output = netD(fake.detach()).reshape(-1)
            lossD_fake = criterion(output, label)

            lossD = lossD_real + lossD_fake
            lossD.backward()
            optimizerD.step()

            ### Train Generator: max log(D(G(z)))
            netG.zero_grad()
            label = torch.ones(cur_batch_size, device=device)
            # No detach here: gradients must flow through D into G. The
            # gradients this leaves on netD are cleared by netD.zero_grad()
            # at the start of the next iteration.
            output = netD(fake).reshape(-1)
            lossG = criterion(output, label)
            lossG.backward()
            optimizerG.step()

            # Print losses occasionally and print to tensorboard
            if batch_idx % 100 == 0:
                step += 1
                # Adjacent f-strings instead of a backslash continuation so
                # source indentation never leaks into the printed line.
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                    f"Loss D: {lossD:.4f}, loss G: {lossG:.4f} D(x): {D_x:.4f}"
                )

                with torch.no_grad():
                    fixed_fake = netG(fixed_noise)
                    img_grid_real = torchvision.utils.make_grid(
                        data[:32], normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fixed_fake[:32], normalize=True
                    )
                    writer_real.add_image(
                        "Mnist Real Images", img_grid_real, global_step=step
                    )
                    writer_fake.add_image(
                        "Mnist Fake Images", img_grid_fake, global_step=step
                    )
finally:
    # Flush and release the event files even if training is interrupted,
    # so the most recent images are not lost.
    writer_real.close()
    writer_fake.close()

View File

@@ -1,4 +0,0 @@
### Generative Adversarial Network
DCGAN_mnist.py: main file and training network
model_utils.py: Generator and discriminator implementation

View File

@@ -1,76 +0,0 @@
"""
Discriminator and Generator implementation from DCGAN paper
that we import in the main (DCGAN_mnist.py) file.
"""
import torch
import torch.nn as nn
class Discriminator(nn.Module):
    """Convolutional discriminator in the style of the DCGAN paper.

    Five 4x4 convolutions reduce the input to a single value per sample,
    which a sigmoid squashes into (0, 1).

    Args:
        channels_img: number of channels in the input images.
        features_d: base channel width; the hidden convolutions use
            features_d, features_d*2, features_d*4 and features_d*8 channels.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        halve = dict(kernel_size=4, stride=2, padding=1)

        # Shape comments assume a 64 x 64 input.
        # N x channels_img x 64 x 64
        layers = [
            nn.Conv2d(channels_img, features_d, **halve),
            nn.LeakyReLU(0.2),
        ]
        # N x features_d x 32 x 32
        # Three conv -> batch-norm -> leaky-ReLU stages, each doubling the
        # channel count.
        for mult in (1, 2, 4):
            in_ch = features_d * mult
            out_ch = in_ch * 2
            layers.append(nn.Conv2d(in_ch, out_ch, **halve))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2))
        # N x features_d*8 x 4 x 4
        layers.append(
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0)
        )
        # N x 1 x 1 x 1
        layers.append(nn.Sigmoid())

        # Kept as one flat Sequential so sub-module indices (net.0, net.1, ...)
        # stay the same.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the network and return the sigmoid output."""
        scores = self.net(x)
        return scores
class Generator(nn.Module):
    """Transposed-convolution generator in the style of the DCGAN paper.

    Five 4x4 transposed convolutions expand a noise tensor into an image;
    a final tanh bounds the output to (-1, 1).

    Args:
        channels_noise: number of channels in the input noise tensor.
        channels_img: number of channels in the generated images.
        features_g: base channel width; the hidden layers use
            features_g*16, features_g*8, features_g*4 and features_g*2
            channels.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        double = dict(kernel_size=4, stride=2, padding=1)

        # N x channels_noise x 1 x 1
        widest = features_g * 16
        layers = [
            nn.ConvTranspose2d(
                channels_noise, widest, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(widest),
            nn.ReLU(),
        ]
        # N x features_g*16 x 4 x 4
        # Three transposed-conv -> batch-norm -> ReLU stages, each halving
        # the channel count.
        for mult in (16, 8, 4):
            in_ch = features_g * mult
            out_ch = features_g * (mult // 2)
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, **double))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU())
        layers.append(nn.ConvTranspose2d(features_g * 2, channels_img, **double))
        # N x channels_img x 64 x 64
        layers.append(nn.Tanh())

        # Kept as one flat Sequential so sub-module indices (net.0, net.1, ...)
        # stay the same.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the noise tensor x through the network and return the image."""
        images = self.net(x)
        return images