"""
Discriminator and Generator implementation from DCGAN paper

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>

* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """DCGAN-style discriminator (critic) for 64x64 images.

    Maps an input of shape N x channels_img x 64 x 64 to a single
    unbounded score per image, shape N x 1 x 1 x 1 (no final sigmoid).
    Uses InstanceNorm2d with affine parameters inside the downsampling
    blocks; the first and last convolutions carry no normalization.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # Each stage halves the spatial resolution: 64 -> 32 -> 16 -> 8 -> 4 -> 1.
        layers = [
            # Input: N x channels_img x 64 x 64 — first conv has no norm.
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three conv/norm/activation blocks, doubling the channel count each time.
        for width in (1, 2, 4):
            layers.append(
                self._block(features_d * width, features_d * width * 2, 4, 2, 1)
            )
        # Remaining 4x4 feature map is collapsed to a single 1x1 score.
        layers.append(nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """One downsampling unit: conv (no bias) -> affine InstanceNorm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,  # bias is redundant before the affine normalization
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score a batch of images; returns shape N x 1 x 1 x 1."""
        return self.disc(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Generator(nn.Module):
    """DCGAN generator producing 64x64 images from latent vectors.

    Maps an input of shape N x channels_noise x 1 x 1 to an image batch of
    shape N x channels_img x 64 x 64 with values in [-1, 1] (Tanh output).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        layers = [
            # Input: N x channels_noise x 1 x 1 -> 4x4 feature map.
            self._block(channels_noise, features_g * 16, 4, 1, 0),
        ]
        # Each block doubles the spatial size (4 -> 8 -> 16 -> 32) and
        # halves the channel count (16f -> 8f -> 4f -> 2f).
        for width in (16, 8, 4):
            layers.append(
                self._block(features_g * width, features_g * width // 2, 4, 2, 1)
            )
        # Final upsample 32 -> 64 maps to image channels; no norm on the
        # output layer, just Tanh to land in [-1, 1].
        layers.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """One upsampling unit: transposed conv (no bias) -> BatchNorm -> ReLU."""
        upconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,  # bias is redundant before BatchNorm
        )
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Generate images from latent input x (N x channels_noise x 1 x 1)."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_weights(model):
    """Initialize *model*'s weights in place, following the DCGAN paper.

    Conv / transposed-conv weights are drawn from N(0, 0.02).  Affine
    normalization layers (BatchNorm2d, and InstanceNorm2d when affine=True)
    get scale ~ N(1.0, 0.02) and zero bias, matching the PyTorch DCGAN
    tutorial.  The previous version drew normalization scales from
    N(0, 0.02) too, which centers the gamma parameters near zero and
    nearly cancels the normalized activations; it also skipped
    InstanceNorm2d entirely even though the Discriminator uses it with
    affine=True.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            # InstanceNorm2d only carries parameters when affine=True,
            # so guard against weight/bias being None.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias.data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test():
    """Smoke test: verify output shapes of Discriminator and Generator."""
    batch, img_channels, height, width = 8, 3, 64, 64
    latent_dim = 100
    # Discriminator: image batch in -> one score per image out.
    images = torch.randn((batch, img_channels, height, width))
    critic = Discriminator(img_channels, 8)
    assert critic(images).shape == (batch, 1, 1, 1), "Discriminator test failed"
    # Generator: latent batch in -> image batch out.
    gen = Generator(latent_dim, img_channels, 8)
    latent = torch.randn((batch, latent_dim, 1, 1))
    assert gen(latent).shape == (batch, img_channels, height, width), "Generator test failed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# test()
|