Files

119 lines
3.9 KiB
Python
Raw Permalink Normal View History

2021-01-30 21:49:15 +01:00
"""
Training of DCGAN network with WGAN loss

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
2022-12-21 14:03:08 +01:00
from tqdm import tqdm
2021-01-30 21:49:15 +01:00
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
# ---------------------------------------------------------------------------
# Hyperparameters and run configuration
# ---------------------------------------------------------------------------
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Optimisation settings
LEARNING_RATE = 5e-5  # learning rate passed to both RMSprop optimizers
BATCH_SIZE = 64
NUM_EPOCHS = 5
CRITIC_ITERATIONS = 5  # critic updates performed per generator update
WEIGHT_CLIP = 0.01  # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# Data and model dimensions
IMAGE_SIZE = 64  # size images are resized to
CHANNELS_IMG = 1  # number of image channels
Z_DIM = 128  # channel dimension of the generator's noise input
FEATURES_CRITIC = 64
FEATURES_GEN = 64
# Preprocessing applied to every image: resize to IMAGE_SIZE, convert to a
# tensor, then normalize each channel with mean 0.5 and std 0.5.
# The pipeline gets its own name so the imported `transforms` module is not
# overwritten and stays usable further down the file.
channel_stats = [0.5] * CHANNELS_IMG
image_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
)

# MNIST is downloaded into "dataset/" on first use.
dataset = datasets.MNIST(root="dataset/", transform=image_transform, download=True)
# To train on CelebA instead, comment out the MNIST line above and uncomment:
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=image_transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Build the generator and the critic (the Discriminator class is used as the
# WGAN critic), move each to the selected device, then run the project's
# weight initialisation on both. Construction and initialisation order is
# kept as-is because it determines how random numbers are consumed.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN)
gen = gen.to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC)
critic = critic.to(device)
initialize_weights(gen)
initialize_weights(critic)

# Initialize the optimizers: one RMSprop instance per network, same learning rate.
opt_gen = optim.RMSprop(params=gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(params=critic.parameters(), lr=LEARNING_RATE)
# TensorBoard plotting state: a batch of 32 noise vectors sampled once up
# front, plus one writer per image stream (real / generated).
fixed_noise = torch.randn(32, Z_DIM, 1, 1)
fixed_noise = fixed_noise.to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # passed as global_step when images are logged

# Put both networks in training mode before the loop starts.
gen.train()
critic.train()
# WGAN training loop.
# For every batch the critic is updated CRITIC_ITERATIONS times (its weights
# are clipped after each step), then the generator is updated once. Every
# 100 batches the losses are printed and image grids are sent to TensorBoard.
for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # Detach the fakes so the critic update does not backpropagate
            # into the generator. The generator's graph behind `fake` stays
            # untouched, so no retain_graph is needed for the later
            # generator step.
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # clip critic weights between -WEIGHT_CLIP and WEIGHT_CLIP
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # Re-score the last batch of fakes with the freshly updated critic;
        # this time gradients flow back through `fake` into the generator.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Use the fixed noise batch so successive grids show the same
                # latent points and can be compared across steps.
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            # Restore training mode for the next batch.
            gen.train()
            critic.train()

# Close the writers once training is done so pending summaries are flushed.
writer_real.close()
writer_fake.close()