checked GAN code

This commit is contained in:
Aladdin Persson
2022-12-21 14:03:08 +01:00
parent b6985eccc9
commit c646ef65e2
14 changed files with 225 additions and 270 deletions

View File

@@ -1,5 +1,9 @@
"""
Training of DCGAN network with WGAN loss
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
@@ -9,6 +13,7 @@ import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
@@ -61,7 +66,7 @@ critic.train()
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (data, _) in enumerate(loader):
for batch_idx, (data, _) in enumerate(tqdm(loader)):
data = data.to(device)
cur_batch_size = data.shape[0]
@@ -111,4 +116,4 @@ for epoch in range(NUM_EPOCHS):
step += 1
gen.train()
critic.train()
critic.train()