mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
Initial commit
This commit is contained in:
107
ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py
Normal file
107
ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Fully connected discriminator producing one score in [0, 1] per row.

    Args:
        in_features: Width of each (flattened) input row fed to the first
            linear layer.
    """

    def __init__(self, in_features):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),
            # Sigmoid squashes the single output so BCELoss can consume it.
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator score for every row of ``x``."""
        scores = self.disc(x)
        return scores
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flat images.

    Args:
        z_dim: Width of each latent input row.
        img_dim: Width of each generated (flattened) output row.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, img_dim),
            # Tanh bounds outputs to [-1, 1], the range the real inputs are
            # expected to be normalized to.
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated samples for the latent batch ``x``."""
        samples = self.gen(x)
        return samples
|
||||
|
||||
|
||||
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Fixed latent batch reused for every TensorBoard snapshot so that successive
# image grids are directly comparable.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
# Bound to `transform` (singular) so the imported `transforms` module is not
# shadowed by the composed pipeline.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # Separate name: the `batch_size` hyperparameter must not be
        # overwritten by the size of the current (possibly partial) batch.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # Detach so the discriminator update does not backpropagate through
        # the generator; the generator graph stays intact for its own step
        # below, so retain_graph is not needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log once per epoch (first batch only).
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
|
||||
96
ML/Pytorch/GANs/2. DCGAN/model.py
Normal file
96
ML/Pytorch/GANs/2. DCGAN/model.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Strided-convolution discriminator ending in a Sigmoid.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base channel width; later layers use 2x, 4x and 8x of it.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        widths = [features_d, features_d * 2, features_d * 4, features_d * 8]
        layers = [
            # input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # _block(in_channels, out_channels, kernel_size, stride, padding)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        # After all _block img output is 4x4 (Conv2d below makes into 1x1)
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one bias-free Conv2d followed by LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        # NOTE: nn.BatchNorm2d(out_channels) is currently disabled in this block.
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return the discriminator output for the image batch ``x``."""
        return self.disc(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Transposed-convolution generator ending in a Tanh.

    Args:
        channels_noise: Number of channels of the latent input.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; earlier layers use multiples of it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per block.
        # Input: N x channels_noise x 1 x 1
        block_specs = [
            (channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            (features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            (features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            (features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
        ]
        stages = [self._block(*spec) for spec in block_specs]
        stages.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        # Output: N x channels_img x 64 x 64
        stages.append(nn.Tanh())
        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one bias-free ConvTranspose2d followed by ReLU."""
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        # NOTE: nn.BatchNorm2d(out_channels) is currently disabled in this block.
        return nn.Sequential(deconv, nn.ReLU())

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        return self.net(x)
|
||||
|
||||
|
||||
def initialize_weights(model):
    """Re-draw conv / batch-norm weights from N(0, 0.02) in place.

    Only the ``weight`` of Conv2d, ConvTranspose2d and BatchNorm2d sub-modules
    is touched; every other module (and all biases) is left as is.
    """
    # Initializes weights according to the DCGAN paper
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for module in model.modules():
        if not isinstance(module, target_types):
            continue
        nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
|
||||
|
||||
def test():
    """Smoke-test the output shapes of Discriminator and Generator."""
    batch, img_channels, height, width = 8, 3, 64, 64
    latent_channels = 100

    images = torch.randn((batch, img_channels, height, width))
    disc = Discriminator(img_channels, 8)
    # One score per image.
    assert disc(images).shape == (batch, 1, 1, 1), "Discriminator test failed"

    gen = Generator(latent_channels, img_channels, 8)
    z = torch.randn((batch, latent_channels, 1, 1))
    expected_shape = (batch, img_channels, height, width)
    assert gen(z).shape == expected_shape, "Generator test failed"
|
||||
|
||||
|
||||
# test()
|
||||
105
ML/Pytorch/GANs/2. DCGAN/train.py
Normal file
105
ML/Pytorch/GANs/2. DCGAN/train.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Training of DCGAN network on MNIST dataset with Discriminator
|
||||
and Generator imported from model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

# Bound to `transform` (singular) so the imported `transforms` module is not
# shadowed by the composed pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transform, download=True
)

# comment mnist above and uncomment below if train on CelebA
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Fixed latent batch reused for every TensorBoard snapshot.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # Match the fake batch to the real one; the last batch of an epoch
        # can be smaller than BATCH_SIZE.
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # Detached: the discriminator step must not backprop into the generator.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
|
||||
98
ML/Pytorch/GANs/3. WGAN/model.py
Normal file
98
ML/Pytorch/GANs/3. WGAN/model.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper,
|
||||
with removed Sigmoid() as output from Discriminator (and therefore
|
||||
it should be called critic)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Convolutional critic: same layout as a DCGAN discriminator, no Sigmoid.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base channel width; later layers use 2x, 4x and 8x of it.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # input: N x channels_img x 64 x 64
        stem = nn.Conv2d(
            channels_img, features_d, kernel_size=4, stride=2, padding=1
        )
        stem_act = nn.LeakyReLU(0.2)
        # _block(in_channels, out_channels, kernel_size, stride, padding)
        down1 = self._block(features_d, features_d * 2, 4, 2, 1)
        down2 = self._block(features_d * 2, features_d * 4, 4, 2, 1)
        down3 = self._block(features_d * 4, features_d * 8, 4, 2, 1)
        # After all _block img output is 4x4 (Conv2d below makes into 1x1)
        head = nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0)
        self.disc = nn.Sequential(stem, stem_act, down1, down2, down3, head)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build bias-free Conv2d -> affine InstanceNorm2d -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return the unbounded critic score for the image batch ``x``."""
        return self.disc(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Transposed-convolution generator with BatchNorm blocks and Tanh output.

    Args:
        channels_noise: Number of channels of the latent input.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; earlier layers use multiples of it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # Input: N x channels_noise x 1 x 1
        up1 = self._block(channels_noise, features_g * 16, 4, 1, 0)  # img: 4x4
        up2 = self._block(features_g * 16, features_g * 8, 4, 2, 1)  # img: 8x8
        up3 = self._block(features_g * 8, features_g * 4, 4, 2, 1)  # img: 16x16
        up4 = self._block(features_g * 4, features_g * 2, 4, 2, 1)  # img: 32x32
        to_image = nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )
        # Output: N x channels_img x 64 x 64
        self.net = nn.Sequential(up1, up2, up3, up4, to_image, nn.Tanh())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        return self.net(x)
|
||||
|
||||
|
||||
def initialize_weights(model):
    """Re-draw conv / batch-norm weights from N(0, 0.02) in place.

    Only Conv2d, ConvTranspose2d and BatchNorm2d sub-modules are matched;
    other module types are skipped and biases are not modified.
    """
    # Initializes weights according to the DCGAN paper
    matched = (
        layer
        for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d))
    )
    for layer in matched:
        nn.init.normal_(layer.weight.data, 0.0, 0.02)
|
||||
|
||||
|
||||
def test():
    """Check that critic and generator produce the expected output shapes."""
    n_samples, n_channels, height, width = 8, 3, 64, 64
    z_channels = 100

    x = torch.randn((n_samples, n_channels, height, width))
    disc = Discriminator(n_channels, 8)
    disc_shape = disc(x).shape
    assert disc_shape == (n_samples, 1, 1, 1), "Discriminator test failed"

    gen = Generator(z_channels, n_channels, 8)
    z = torch.randn((n_samples, z_channels, 1, 1))
    gen_shape = gen(z).shape
    assert gen_shape == (n_samples, n_channels, height, width), "Generator test failed"
|
||||
|
||||
|
||||
# test()
|
||||
114
ML/Pytorch/GANs/3. WGAN/train.py
Normal file
114
ML/Pytorch/GANs/3. WGAN/train.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Training of DCGAN network with WGAN loss
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
# Hyperparameters etc
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 128
NUM_EPOCHS = 5
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01

# Bound to `transform` (singular) so the imported `transforms` module is not
# shadowed by the composed pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
# comment mnist and uncomment below if you want to train on CelebA dataset
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# initialize gen and disc/critic
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# for tensorboard plotting: one fixed latent batch so that successive
# snapshots show the same samples evolving
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (data, _) in enumerate(loader):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # Detached so the critic update does not backprop through the
            # generator; the generator graph of the last iteration stays
            # untouched for the generator step, so retain_graph is not needed.
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # clip critic weights between -0.01, 0.01
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Sample from the fixed latent batch (previously the last
                # training noise was used and fixed_noise was never read).
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            # Back to training mode after the evaluation-mode snapshot.
            gen.train()
            critic.train()
|
||||
84
ML/Pytorch/GANs/4. WGAN-GP/model.py
Normal file
84
ML/Pytorch/GANs/4. WGAN-GP/model.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Discriminator and Generator implementation from DCGAN paper
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Convolutional critic with InstanceNorm blocks and no output activation.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base channel width; later layers use 2x, 4x and 8x of it.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        modules = []
        # input: N x channels_img x 64 x 64
        modules.append(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1)
        )
        modules.append(nn.LeakyReLU(0.2))
        # _block(in_channels, out_channels, kernel_size, stride, padding)
        for multiplier in (1, 2, 4):
            modules.append(
                self._block(
                    features_d * multiplier, features_d * multiplier * 2, 4, 2, 1
                )
            )
        # After all _block img output is 4x4 (Conv2d below makes into 1x1)
        modules.append(
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0)
        )
        self.disc = nn.Sequential(*modules)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build bias-free Conv2d -> affine InstanceNorm2d -> LeakyReLU(0.2)."""
        parts = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*parts)

    def forward(self, x):
        """Return the unbounded critic score for the image batch ``x``."""
        return self.disc(x)
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Transposed-convolution generator with BatchNorm blocks and Tanh output.

    Args:
        channels_noise: Number of channels of the latent input.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; earlier layers use multiples of it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # Channel widths from the latent input down to the last hidden layer.
        # Input: N x channels_noise x 1 x 1
        widths = [
            channels_noise,
            features_g * 16,  # img: 4x4
            features_g * 8,  # img: 8x8
            features_g * 4,  # img: 16x16
            features_g * 2,  # img: 32x32
        ]
        modules = []
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first block maps 1x1 -> 4x4 (stride 1, no padding);
            # the others double the spatial size.
            stride, padding = (1, 0) if index == 0 else (2, 1)
            modules.append(self._block(c_in, c_out, 4, stride, padding))
        modules.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        # Output: N x channels_img x 64 x 64
        modules.append(nn.Tanh())
        self.net = nn.Sequential(*modules)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        parts = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*parts)

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        return self.net(x)
|
||||
|
||||
|
||||
def initialize_weights(model):
    """Re-draw conv / batch-norm weights from N(0, 0.02) in place.

    Walks every sub-module of ``model`` and re-initializes the ``weight`` of
    Conv2d, ConvTranspose2d and BatchNorm2d layers; nothing else is changed.
    """
    # Initializes weights according to the DCGAN paper
    init_mean, init_std = 0.0, 0.02
    for sub_module in model.modules():
        is_target = isinstance(
            sub_module, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
        )
        if is_target:
            nn.init.normal_(sub_module.weight.data, init_mean, init_std)
|
||||
|
||||
|
||||
def test():
    """Verify output shapes of the critic and the generator on random data."""
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    image_shape = (N, in_channels, H, W)

    x = torch.randn(image_shape)
    disc = Discriminator(in_channels, 8)
    scores = disc(x)
    assert scores.shape == (N, 1, 1, 1), "Discriminator test failed"

    gen = Generator(noise_dim, in_channels, 8)
    z = torch.randn((N, noise_dim, 1, 1))
    generated = gen(z)
    assert generated.shape == image_shape, "Generator test failed"
|
||||
|
||||
|
||||
# test()
|
||||
111
ML/Pytorch/GANs/4. WGAN-GP/train.py
Normal file
111
ML/Pytorch/GANs/4. WGAN-GP/train.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Training of WGAN-GP
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils import gradient_penalty, save_checkpoint, load_checkpoint
|
||||
from model import Discriminator, Generator, initialize_weights
|
||||
|
||||
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 100
FEATURES_CRITIC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

# Resize -> tensor -> per-channel normalization with mean 0.5 and std 0.5.
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# comment mnist above and uncomment below for training on CelebA
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers
adam_betas = (0.0, 0.9)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=adam_betas)
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=adam_betas)

# for tensorboard plotting
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/GAN_MNIST/real")
writer_fake = SummaryWriter("logs/GAN_MNIST/fake")
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            score_gap = torch.mean(critic_real) - torch.mean(critic_fake)
            loss_critic = -score_gap + LAMBDA_GP * gp
            critic.zero_grad()
            # The generator graph behind `fake` is reused by the generator
            # step below, hence retain_graph.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        should_log = batch_idx > 0 and batch_idx % 100 == 0
        if should_log:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
|
||||
35
ML/Pytorch/GANs/4. WGAN-GP/utils.py
Normal file
35
ML/Pytorch/GANs/4. WGAN-GP/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake mixtures.

    Each sample is a convex combination ``real * a + fake * (1 - a)`` with one
    uniformly drawn ``a`` per sample. The penalty is the batch mean of
    ``(||d critic(x) / d x||_2 - 1) ** 2`` evaluated at those mixtures.

    Args:
        critic: Callable mapping a batch of images to critic scores.
        real: Batch of real images; the first dimension is the batch.
        fake: Batch of generated images with the same shape as ``real``.
            It may or may not be attached to an autograd graph.
        device: Device the interpolation coefficients are moved to.

    Returns:
        Scalar tensor holding the gradient penalty. It is built with
        ``create_graph=True`` so it can be backpropagated into the critic.
    """
    batch_size = real.shape[0]
    # One coefficient per sample; trailing singleton dims let broadcasting
    # apply it to every element of that sample.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    # If neither input carries gradients (e.g. a detached fake batch), the
    # mixture is a plain leaf tensor and must be marked differentiable,
    # otherwise autograd.grad below raises.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten per sample so the norm is taken over all image elements.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty
|
||||
|
||||
|
||||
def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    """Print a notice and serialize ``state`` to ``filename`` via torch.save.

    Args:
        state: Object to serialize (whatever ``torch.save`` accepts).
        filename: Destination handed to ``torch.save`` unchanged.
    """
    notice = "=> Saving checkpoint"
    print(notice)
    torch.save(state, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint, gen, disc):
    """Print a notice and restore both models from ``checkpoint``.

    Args:
        checkpoint: Mapping with the keys ``'gen'`` and ``'disc'``, each
            holding a state dict.
        gen: Object whose ``load_state_dict`` receives ``checkpoint['gen']``.
        disc: Object whose ``load_state_dict`` receives ``checkpoint['disc']``.
    """
    print("=> Loading checkpoint")
    # Generator first, then discriminator.
    for key, model in (("gen", gen), ("disc", disc)):
        model.load_state_dict(checkpoint[key])
|
||||
205
ML/Pytorch/GANs/5. ProGAN/model.py
Normal file
205
ML/Pytorch/GANs/5. ProGAN/model.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Implementation of ProGAN generator and discriminator with the key
|
||||
attributions from the paper. We have tried to make the implementation
|
||||
compact but a goal is also to keep it readable and understandable.
|
||||
Specifically the key points implemented are:
|
||||
|
||||
1) Progressive growing (of model and layers)
|
||||
2) Minibatch std on Discriminator
|
||||
3) Normalization with PixelNorm
|
||||
4) Equalized Learning Rate (here I cheated and only did it on Conv layers)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from math import log2
|
||||
|
||||
"""
|
||||
Factors is used in Discriminator and Generator for how much
|
||||
the channels should be multiplied and expanded for each layer,
|
||||
so specifically the first 5 layers the channels stay the same,
|
||||
whereas when we increase the img_size (towards the later layers)
|
||||
we decrease the number of channels by 1/2, 1/4, etc.
|
||||
"""
|
||||
factors = [1, 1, 1, 1, 1/2, 1/4, 1/4, 1/8, 1/16]
|
||||
|
||||
|
||||
class WSConv2d(nn.Module):
    """
    Weight scaled Conv2d (Equalized Learning Rate).

    The input is multiplied by a constant ``scale`` before the convolution
    instead of rescaling the weights themselves. ``scale`` is
    ``sqrt(gain / n)`` where ``n`` is the number of weight elements of a
    single output filter.

    Inspired by:
    https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding
        )
        # Number of weights feeding one output channel.
        weights_per_filter = self.conv.weight[0].numel()
        self.scale = (gain / weights_per_filter) ** 0.5

        # initialize conv layer: standard-normal weights, zero bias
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Convolve the rescaled input; the bias is added unscaled by the conv."""
        scaled_input = x * self.scale
        return self.conv(scaled_input)
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
    """Divide every position by the root-mean-square over dim 1."""

    def __init__(self):
        super(PixelNorm, self).__init__()
        # Added under the square root to avoid division by zero.
        self.epsilon = 1e-8

    def forward(self, x):
        """Return ``x / sqrt(mean(x ** 2, dim=1) + epsilon)``."""
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        denominator = torch.sqrt(mean_square + self.epsilon)
        return x / denominator
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """Two WSConv2d layers, each followed by LeakyReLU(0.2) and optional PixelNorm.

    Args:
        in_channels: Input channels of the first convolution.
        out_channels: Output channels of both convolutions.
        use_pixelnorm: When true, PixelNorm is applied after each activation.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Apply conv -> LeakyReLU -> (PixelNorm) twice and return the result."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Progressively grown generator.

    Args:
        z_dim: Number of channels of the latent input.
        in_channels: Channel width after the initial 4x4 layer; the widths of
            later blocks are ``in_channels * factors[idx]``.
        img_size: Target resolution; determines how many blocks are created.
        img_channels: Number of channels of the generated images.
    """

    def __init__(self, z_dim, in_channels, img_size, img_channels=3):
        super(Generator, self).__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([])

        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # Create progression blocks and rgb layers.
        # we need to double img for log2(img_size/4) and
        # +1 in loop for initial 4x4
        num_blocks = int(log2(img_size / 4)) + 1
        channels = in_channels
        for idx in range(num_blocks):
            conv_out = int(in_channels * factors[idx])
            self.prog_blocks.append(ConvBlock(channels, conv_out))
            # 1x1 convolution converting this block's features to image channels.
            self.rgb_layers.append(
                WSConv2d(conv_out, img_channels, kernel_size=1, stride=1, padding=0)
            )
            channels = conv_out

    def fade_in(self, alpha, upscaled, generated):
        """Blend ``generated`` and ``upscaled`` with weight ``alpha``, then tanh.

        No validation is performed: ``alpha`` is not range-checked and the two
        tensors are not checked for equal shape.
        """
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate images using ``steps`` upsampling stages.

        With ``steps == 0`` the output of the first block is converted by the
        first rgb layer and returned directly (no tanh, no fading).
        """
        out = self.prog_blocks[0](self.initial(x))

        if steps == 0:
            return self.rgb_layers[0](out)

        for step in range(1, steps + 1):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # The number of channels in upscaled will stay the same, while
        # out which has moved through prog_blocks might change. To ensure
        # we can convert both to rgb we use different rgb_layers
        # (steps-1) and steps for upscaled, out respectively
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Progressively growing critic, the mirror image of the generator.

    ``rgb_layers[i]`` lifts an RGB image of stage ``i`` to the channel count
    of that stage and ``prog_blocks[i]`` maps stage ``i`` features to the
    channel count of stage ``i - 1``. The final 4x4 feature map gets a
    minibatch-std channel appended and is reduced to one score per sample.
    """

    def __init__(self, img_size, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])

        # Create progression blocks and rgb layers
        channels = in_channels
        for idx in range(int(log2(img_size / 4)) + 1):
            conv_in = int(in_channels * factors[idx])
            conv_out = channels
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            channels = conv_in

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # +1 to in_channels because we concatenate from minibatch std
        self.conv = WSConv2d(in_channels + 1, z_dim, kernel_size=4, stride=1, padding=0)
        self.linear = nn.Linear(z_dim, 1)

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avgpooling and output from CNN"""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean per-position std over the batch.

        ``x`` is indexed as a 4-D tensor (N, C, H, W); the result has shape
        (N, C + 1, H, W). For a batch of at most one sample the spread is
        defined as zero, because ``torch.std`` over a single element is NaN
        and would otherwise turn every downstream activation into NaN.
        """
        if x.shape[0] > 1:
            spread = torch.std(x, dim=0).mean()
        else:
            spread = x.new_zeros(())
        batch_statistics = spread.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score image batch ``x`` that was produced at stage ``steps``.

        Returns a tensor with one column; rows correspond to samples.
        """
        out = self.rgb_layers[steps](x)  # convert from rgb as initial step

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            out = self.conv(out)
            return self.linear(out.view(-1, out.shape[1]))

        # "Reverse" fade-in: blend the pooled image lifted by the previous
        # stage's rgb layer with the pooled output of the newest block.
        downscaled = self.rgb_layers[steps - 1](self.avg_pool(x))
        out = self.avg_pool(self.prog_blocks[steps](out))
        out = self.fade_in(alpha, downscaled, out)

        # Remaining stages: halve the resolution, then apply that stage's block.
        for step in range(steps - 1, 0, -1):
            downscaled = self.avg_pool(out)
            out = self.prog_blocks[step](downscaled)

        out = self.minibatch_std(out)
        out = self.conv(out)
        return self.linear(out.view(-1, out.shape[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build both networks at the largest resolution, run one
    # forward pass through each and print rough wall-clock timings.
    import time

    Z_DIM = 100
    IN_CHANNELS = 16
    img_size = 512
    num_steps = int(log2(img_size / 4))
    x = torch.randn((5, Z_DIM, 1, 1))
    gen = Generator(Z_DIM, IN_CHANNELS, img_size=img_size)
    disc = Discriminator(img_size, Z_DIM, IN_CHANNELS)
    start = time.time()
    # Only request CUDA event timing when a CUDA device actually exists;
    # hard-coding it makes the check fail or warn on CPU-only machines.
    with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
        z = gen(x, alpha=0.5, steps=num_steps)
    print(prof)
    gen_time = time.time() - start  # includes profiler and printing overhead
    t = time.time()
    out = disc(z, 0.01, num_steps)
    disc_time = time.time() - t
    print(gen_time, disc_time)
|
||||
5
ML/Pytorch/GANs/5. ProGAN/test.py
Normal file
5
ML/Pytorch/GANs/5. ProGAN/test.py
Normal file
@@ -0,0 +1,5 @@
|
||||
def func(x=1, y=2, **kwargs):
    """Print ``x`` and ``y`` on one line; extra keyword arguments are accepted and ignored."""
    print(x, y)


# func has no return statement, so the second line printed is "None".
result = func(x=3, y=4)
print(result)
|
||||
165
ML/Pytorch/GANs/5. ProGAN/train.py
Normal file
165
ML/Pytorch/GANs/5. ProGAN/train.py
Normal file
@@ -0,0 +1,165 @@
|
||||
""" Training of ProGAN using WGAN-GP loss"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
|
||||
from model import Discriminator, Generator
|
||||
from math import log2
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
# Let cuDNN auto-tune convolution algorithms. The flag is named `benchmark`;
# assigning to a misspelled attribute is silently accepted and has no effect.
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)

# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
# Batch size per resolution stage, indexed by log2(image_size / 4).
BATCH_SIZES = [128, 128, 64, 16, 8, 4, 2, 2, 1]
IMAGE_SIZE = 128
CHANNELS_IMG = 3
Z_DIM = 128
IN_CHANNELS = 128
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
# Number of resolution stages: 4x4 up to IMAGE_SIZE x IMAGE_SIZE.
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1
# Epochs to train at each stage. An earlier doubling schedule (2 ** i) was
# assigned here and immediately overwritten, so only this value ever applied.
PROGRESSIVE_EPOCHS = [8] * NUM_STEPS
# Fixed latents so logged samples are comparable across training.
fixed_noise = torch.randn(8, Z_DIM, 1, 1).to(device)
NUM_WORKERS = 4
|
||||
|
||||
def get_loader(image_size):
    """Build the data loader for one resolution stage.

    Images from the ``celeb_dataset`` folder are resized to
    ``image_size`` x ``image_size``, converted to tensors and normalized with
    mean 0.5 and std 0.5 per channel. The batch size is looked up in
    ``BATCH_SIZES`` by ``log2(image_size / 4)``.

    Returns:
        A ``(loader, dataset)`` tuple.
    """
    channel_mean = [0.5] * CHANNELS_IMG
    channel_std = [0.5] * CHANNELS_IMG
    preprocessing = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform = transforms.Compose(preprocessing)

    stage_index = int(log2(image_size / 4))
    stage_batch_size = BATCH_SIZES[stage_index]

    dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=stage_batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset
|
||||
|
||||
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
):
    """Run one epoch of WGAN-GP training at resolution stage ``step``.

    For every batch the critic is updated ``CRITIC_ITERATIONS`` times, the
    generator once, and the fade-in weight ``alpha`` is increased. Every 300
    batches the losses and sample images are written to ``writer``.

    Returns:
        The updated ``(tensorboard_step, alpha)`` so the caller can carry
        both into the next epoch.
    """
    start = time.time()
    total_time = 0
    training = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(training):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step).reshape(-1)
            critic_fake = critic(fake, alpha, step).reshape(-1)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
            )
            # The generator step below rebuilds its own graph from `noise`,
            # so this graph does not need to be retained after backward.
            loss_critic.backward()
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # gen.zero_grad() also discards the gradients that the critic
        # backward pass accumulated in the generator.
        gen.zero_grad()
        fake = gen(noise, alpha, step)
        gen_fake = critic(fake, alpha, step).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        loss_gen.backward()
        opt_gen.step()

        # Update alpha and ensure less than 1. The increments add up to 1
        # once half of the samples scheduled for this stage have been seen.
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)
        total_time += time.time() - model_start

        if batch_idx % 300 == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step)
            plot_to_tensorboard(
                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
            )
            tensorboard_step += 1

    print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
    return tensorboard_step, alpha
|
||||
|
||||
|
||||
def main():
    """Train ProGAN stage by stage, optionally resuming from a checkpoint.

    If ``celeba_wgan_gp.pth.tar`` exists, its generator and critic weights
    are loaded and training resumes at stage 3, as it did before. Without a
    checkpoint, training starts from the first stage instead of crashing on
    the missing file.
    """
    import os  # local import: only needed for the checkpoint existence check

    # initialize gen and disc, note: discriminator should be called critic,
    # according to WGAN paper (since it no longer outputs between [0, 1])
    gen = Generator(Z_DIM, IN_CHANNELS, img_size=IMAGE_SIZE, img_channels=CHANNELS_IMG).to(device)
    critic = Discriminator(IMAGE_SIZE, Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)

    # initialize optimizers
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

    # for tensorboard plotting
    writer = SummaryWriter("logs/gan")

    checkpoint_file = "celeba_wgan_gp.pth.tar"
    resumed = os.path.isfile(checkpoint_file)
    if resumed:
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        load_checkpoint(torch.load(checkpoint_file, map_location=device), gen, critic)
    gen.train()
    critic.train()

    # Stages below this index are skipped; they are assumed to be covered by
    # the loaded checkpoint. NOTE(review): 3 is hard-coded - confirm it matches
    # the checkpoint actually being resumed.
    first_step = 3 if resumed else 0

    tensorboard_step = 0
    for step, num_epochs in enumerate(PROGRESSIVE_EPOCHS):
        alpha = 0.01  # fade-in restarts for every new resolution
        if step < first_step:
            continue

        if step == 4:
            print(f"Img size is: {4*2**step}")

        loader, dataset = get_loader(4 * 2 ** step)
        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
            )

            # NOTE(review): indentation was lost in the source; saving after
            # every epoch is the assumed placement - confirm.
            checkpoint = {
                'gen': gen.state_dict(),
                'critic': critic.state_dict(),
                'opt_gen': opt_gen.state_dict(),
                'opt_critic': opt_critic.state_dict(),
            }
            save_checkpoint(checkpoint)


if __name__ == "__main__":
    main()
|
||||
54
ML/Pytorch/GANs/5. ProGAN/utils.py
Normal file
54
ML/Pytorch/GANs/5. ProGAN/utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import torch.nn as nn
|
||||
|
||||
# Print losses occasionally and print to tensorboard
|
||||
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    """Log both losses and image grids of real and fake samples.

    Args:
        writer: object providing ``add_scalar`` and ``add_image``.
        loss_critic: critic loss, logged as "Loss Critic".
        loss_gen: generator loss, logged as "Loss Gen".
        real: batch of real images; the first 8 are plotted.
        fake: batch of generated images; the first 8 are plotted.
        tensorboard_step: value used as ``global_step`` for every entry.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # The generator loss was passed in but never written before.
    writer.add_scalar("Loss Gen", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
|
||||
|
||||
|
||||
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: batch of real images; indexed as (N, C, H, W).
        fake: batch of generated images, same shape as ``real``. It may be
            detached; gradients are taken with respect to the interpolation.
        alpha: fade-in weight forwarded to ``critic``.
        train_step: resolution stage forwarded to ``critic``.
        device: device the per-sample mixing weights are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``.
    """
    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over channels and pixels.
    beta = torch.rand((batch_size, 1, 1, 1)).to(device)
    interpolated_images = real * beta + fake * (1 - beta)
    # Without this, autograd.grad raises when neither input requires grad
    # (for example when the caller passes a detached fake batch).
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # the penalty itself is differentiated later
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
|
||||
|
||||
|
||||
def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
|
||||
|
||||
def load_checkpoint(checkpoint, gen, disc, opt_gen=None, opt_disc=None):
    """Restore model (and optionally optimizer) state from ``checkpoint``.

    Args:
        checkpoint: mapping with the keys ``'gen'`` and ``'critic'``, plus
            ``'opt_gen'`` / ``'opt_critic'`` for each optimizer that is given.
        gen: module that receives ``checkpoint['gen']``.
        disc: module that receives ``checkpoint['critic']``.
        opt_gen: optional optimizer that receives ``checkpoint['opt_gen']``.
        opt_disc: optional optimizer that receives ``checkpoint['opt_critic']``.

    Each optimizer is restored on its own; previously an optimizer passed
    without its counterpart was silently ignored.
    """
    print("=> Loading checkpoint")
    gen.load_state_dict(checkpoint['gen'])
    disc.load_state_dict(checkpoint['critic'])

    if opt_gen is not None:
        opt_gen.load_state_dict(checkpoint['opt_gen'])
    if opt_disc is not None:
        opt_disc.load_state_dict(checkpoint['opt_critic'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user