mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
checked GAN code
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
Training of DCGAN network on MNIST dataset with Discriminator
|
||||
and Generator imported from models.py
|
||||
|
||||
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
||||
* 2020-11-01: Initial coding
|
||||
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -35,11 +39,12 @@ transforms = transforms.Compose(
|
||||
)
|
||||
|
||||
# If you train on MNIST, remember to set channels_img to 1
|
||||
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
|
||||
download=True)
|
||||
dataset = datasets.MNIST(
|
||||
root="dataset/", train=True, transform=transforms, download=True
|
||||
)
|
||||
|
||||
# comment mnist above and uncomment below if train on CelebA
|
||||
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
|
||||
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
|
||||
@@ -92,14 +97,10 @@ for epoch in range(NUM_EPOCHS):
|
||||
with torch.no_grad():
|
||||
fake = gen(fixed_noise)
|
||||
# take out (up to) 32 examples
|
||||
img_grid_real = torchvision.utils.make_grid(
|
||||
real[:32], normalize=True
|
||||
)
|
||||
img_grid_fake = torchvision.utils.make_grid(
|
||||
fake[:32], normalize=True
|
||||
)
|
||||
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
|
||||
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
|
||||
|
||||
writer_real.add_image("Real", img_grid_real, global_step=step)
|
||||
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
|
||||
|
||||
step += 1
|
||||
step += 1
|
||||
|
||||
Reference in New Issue
Block a user