update progan

This commit is contained in:
Aladdin Persson
2021-03-24 21:12:13 +01:00
7 changed files with 24 additions and 28 deletions

View File

@@ -51,7 +51,7 @@ class GoogLeNet(nn.Module):
self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
self.dropout = nn.Dropout(p=0.4)
self.fc1 = nn.Linear(1024, 1000)
self.fc1 = nn.Linear(1024, num_classes)
if self.aux_logits:
self.aux1 = InceptionAux(512, num_classes)
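For context on this first hunk: the classifier head is now sized by the constructor's num_classes argument instead of being hardcoded to 1000 (the ImageNet default). A minimal sketch of the pattern (only the head; GoogLeNet itself has many more layers):

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    # Sketch of the parameterized classifier head from the hunk above.
    def __init__(self, num_classes=1000):
        super().__init__()
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.dropout = nn.Dropout(p=0.4)
        # Sized by num_classes rather than fixed to 1000, so the model
        # works on datasets with a different number of labels.
        self.fc1 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.avgpool(x)             # (N, 1024, 7, 7) -> (N, 1024, 1, 1)
        x = self.dropout(x.flatten(1))  # (N, 1024)
        return self.fc1(x)              # (N, num_classes)

x = torch.randn(2, 1024, 7, 7)
print(Head(num_classes=10)(x).shape)  # torch.Size([2, 10])
```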

View File

@@ -23,7 +23,7 @@ class block(nn.Module):
super(block, self).__init__()
self.expansion = 4
self.conv1 = nn.Conv2d(
in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0
in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0, bias=False
)
self.bn1 = nn.BatchNorm2d(intermediate_channels)
self.conv2 = nn.Conv2d(
@@ -32,6 +32,7 @@ class block(nn.Module):
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(intermediate_channels)
self.conv3 = nn.Conv2d(
@@ -40,6 +41,7 @@ class block(nn.Module):
kernel_size=1,
stride=1,
padding=0,
bias=False
)
self.bn3 = nn.BatchNorm2d(intermediate_channels * self.expansion)
self.relu = nn.ReLU()
@@ -70,7 +72,7 @@ class ResNet(nn.Module):
def __init__(self, block, layers, image_channels, num_classes):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -122,6 +124,7 @@ class ResNet(nn.Module):
intermediate_channels * 4,
kernel_size=1,
stride=stride,
bias=False
),
nn.BatchNorm2d(intermediate_channels * 4),
)
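All the bias=False changes in this file follow the standard conv + BatchNorm pattern: BatchNorm subtracts the per-channel batch mean and adds its own learned shift (beta), so a bias on the preceding conv is redundant and only wastes parameters. A small self-contained demonstration:

```python
import torch
import torch.nn as nn

# In training mode, BatchNorm normalizes with batch statistics, so any
# constant per-channel bias added by the preceding conv is cancelled.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8).train()

x = torch.randn(4, 3, 16, 16)
y1 = bn(conv(x))

with torch.no_grad():
    conv.bias += 5.0  # shift every output channel by a constant

y2 = bn(conv(x))
print(torch.allclose(y1, y2, atol=1e-4))  # True: the conv bias had no effect
```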

View File

@@ -1,20 +1,25 @@
# ProGAN
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation are, I would say, pretty close to the original paper (I'll include some example results below), but because of time limitations I only trained up to 256x256 and with a smaller model than in the paper. Increasing the number of channels to 512, instead of the 256 I trained with, would probably make the results even better :)
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation are, I would say, on par with the paper; I'll include some example results below.
## Results
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
|First are some cherry-picked examples; second are samples from random latent vectors|
||
|:---:|
|![](results/result1.png)|
|![](results/64_examples.png)|
|![](results/result1.png)|
### Celeb-HQ dataset
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
### Download pretrained weights
Download pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).
Extract the zip file and put the .pth.tar files in the same directory as the Python files. Make sure you set LOAD_MODEL = True in config.py.

View File

@@ -2,7 +2,7 @@ import cv2
import torch
from math import log2
START_TRAIN_AT_IMG_SIZE = 4
START_TRAIN_AT_IMG_SIZE = 128
DATASET = 'celeb_dataset'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
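Setting START_TRAIN_AT_IMG_SIZE = 128 makes training begin at the 128x128 stage instead of 4x4; the training script converts the image size to a progressive step with step = int(log2(size / 4)) (see the training loop below). A quick sketch of that mapping:

```python
from math import log2

# Progressive-growing stages start at 4x4 and double each step:
# 4 -> 0, 8 -> 1, 16 -> 2, 32 -> 3, 64 -> 4, 128 -> 5, 256 -> 6.
for size in [4, 8, 16, 32, 64, 128, 256]:
    step = int(log2(size / 4))
    print(f"{size:>3}x{size:<3} -> step {step}")
```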

View File

@@ -134,7 +134,7 @@ class Generator(nn.Module):
class Discriminator(nn.Module):
def __init__(self, in_channels, img_channels=3):
def __init__(self, z_dim, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)

View File

@@ -60,24 +60,19 @@ def train_fn(
scaler_critic,
):
loop = tqdm(loader, leave=True)
# critic_losses = []
reals = 0
fakes = 0
for batch_idx, (real, _) in enumerate(loop):
real = real.to(config.DEVICE)
cur_batch_size = real.shape[0]
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
# which is equivalent to minimizing the negative of the expression
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
reals += critic_real.mean().item()
fakes += critic_fake.mean().item()
gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
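Two details in this hunk: the latent noise gains explicit 1x1 spatial dimensions, torch.randn(N, Z_DIM, 1, 1), because the ProGAN generator's first block is convolutional and expects 4-D input; and gradient_penalty now takes alpha and step because the critic's forward signature is critic(x, alpha, step), so the penalty must call it the same way on the interpolated images. A sketch of the standard WGAN-GP penalty adapted to that interface (a generic implementation, not copied from this repo):

```python
import torch

def gradient_penalty(critic, real, fake, alpha, step, device="cpu"):
    # WGAN-GP: penalize the critic's gradient norm at random points
    # interpolated between real and fake images.
    N = real.shape[0]
    eps = torch.rand(N, 1, 1, 1, device=device).expand_as(real)
    interpolated = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)

    # Call the critic with the ProGAN fade-in arguments.
    mixed_scores = critic(interpolated, alpha, step)
    gradients = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # keep the graph so the penalty is differentiable
    )[0]

    gradients = gradients.reshape(N, -1)
    return torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
```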
@@ -119,8 +114,6 @@ def train_fn(
tensorboard_step += 1
loop.set_postfix(
reals=reals / (batch_idx + 1),
fakes=fakes / (batch_idx + 1),
gp=gp.item(),
loss_critic=loss_critic.item(),
)
@@ -131,11 +124,12 @@ def train_fn(
def main():
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
# but really who cares..
gen = Generator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
# initialize optimizers and scalers for FP16 training
@@ -147,7 +141,7 @@ def main():
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
writer = SummaryWriter(f"logs/gan")
writer = SummaryWriter(f"logs/gan1")
if config.LOAD_MODEL:
load_checkpoint(
@@ -163,10 +157,6 @@ def main():
tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
generate_examples(gen, step)
import sys
sys.exit()
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
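alpha is the fade-in weight that blends the newly grown layer with the upsampled output of the previous stage; it starts near zero (1e-5 avoids an exactly zero contribution on the first batch) and is ramped to 1 during each stage. A hypothetical linear schedule, with names assumed rather than taken from this repo:

```python
def update_alpha(alpha, cur_batch_size, num_epochs, dataset_len):
    # Linear ramp: alpha reaches 1 after roughly half the images seen
    # during this resolution stage, then stays clamped at 1.
    alpha += cur_batch_size / (0.5 * num_epochs * dataset_len)
    return min(alpha, 1.0)

alpha = 1e-5  # start near zero, as in the training loop above
alpha = update_alpha(alpha, cur_batch_size=16, num_epochs=30, dataset_len=30_000)
```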
@@ -197,4 +187,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@@ -87,6 +87,4 @@ def generate_examples(gen, steps, truncation=0.7, n=100):
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
img = gen(noise, alpha, steps)
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
gen.train()
gen.train()
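generate_examples draws latents from a truncated normal rather than a full Gaussian (the truncation trick): clipping z to [-0.7, 0.7] trades sample diversity for fidelity. A standalone sketch of that sampling step:

```python
import torch
from scipy.stats import truncnorm

def truncated_noise(n, z_dim, truncation=0.7, device="cpu"):
    # Sample z from a standard normal truncated to [-truncation, truncation];
    # smaller truncation -> higher-fidelity but less diverse samples.
    values = truncnorm.rvs(-truncation, truncation, size=(n, z_dim, 1, 1))
    return torch.tensor(values, dtype=torch.float32, device=device)

z = truncated_noise(4, 256)
print(z.shape, z.abs().max() <= 0.7)  # torch.Size([4, 256, 1, 1]) tensor(True)
```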