mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
checked GAN code
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -11,9 +15,7 @@ class Discriminator(nn.Module):
|
||||
super(Discriminator, self).__init__()
|
||||
self.disc = nn.Sequential(
|
||||
# input: N x channels_img x 64 x 64
|
||||
nn.Conv2d(
|
||||
channels_img, features_d, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2),
|
||||
# _block(in_channels, out_channels, kernel_size, stride, padding)
|
||||
self._block(features_d, features_d * 2, 4, 2, 1),
|
||||
@@ -34,7 +36,7 @@ class Discriminator(nn.Module):
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
#nn.BatchNorm2d(out_channels),
|
||||
# nn.BatchNorm2d(out_channels),
|
||||
nn.LeakyReLU(0.2),
|
||||
)
|
||||
|
||||
@@ -68,7 +70,7 @@ class Generator(nn.Module):
|
||||
padding,
|
||||
bias=False,
|
||||
),
|
||||
#nn.BatchNorm2d(out_channels),
|
||||
# nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
@@ -82,6 +84,7 @@ def initialize_weights(model):
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
|
||||
|
||||
def test():
|
||||
N, in_channels, H, W = 8, 3, 64, 64
|
||||
noise_dim = 100
|
||||
@@ -91,6 +94,8 @@ def test():
|
||||
gen = Generator(noise_dim, in_channels, 8)
|
||||
z = torch.randn((N, noise_dim, 1, 1))
|
||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||
print("Success, tests passed!")
|
||||
|
||||
|
||||
# test()
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
Reference in New Issue
Block a user