mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
checked GAN code
This commit is contained in:
@@ -2,6 +2,10 @@
|
||||
Discriminator and Generator implementation from DCGAN paper,
|
||||
with removed Sigmoid() as output from Discriminator (and therefor
|
||||
it should be called critic)
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -93,6 +97,7 @@ def test():
|
||||
gen = Generator(noise_dim, in_channels, 8)
|
||||
z = torch.randn((N, noise_dim, 1, 1))
|
||||
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
|
||||
print("Success, tests passed!")
|
||||
|
||||
|
||||
# test()
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
@@ -1,5 +1,9 @@
|
||||
"""
|
||||
Training of DCGAN network with WGAN loss
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -9,6 +13,7 @@ import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
@@ -61,7 +66,7 @@ critic.train()
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
# Target labels not needed! <3 unsupervised
|
||||
for batch_idx, (data, _) in enumerate(loader):
|
||||
for batch_idx, (data, _) in enumerate(tqdm(loader)):
|
||||
data = data.to(device)
|
||||
cur_batch_size = data.shape[0]
|
||||
|
||||
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
|
||||
|
||||
step += 1
|
||||
gen.train()
|
||||
critic.train()
|
||||
critic.train()
|
||||
Reference in New Issue
Block a user