stylegan, esrgan, srgan code

Author: Aladdin Persson
Date:   2021-05-15 14:58:41 +02:00
commit 5033cbb567 (parent a2ee9271b5)
34 changed files with 1569 additions and 0 deletions

(binary files added: one 131 KiB image and one other binary; previews not shown)

ESRGAN/config.py

@@ -0,0 +1,48 @@
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth"
CHECKPOINT_DISC = "disc.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10000
BATCH_SIZE = 16
LAMBDA_GP = 10
NUM_WORKERS = 4
HIGH_RES = 128
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3
highres_transform = A.Compose(
[
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
ToTensorV2(),
]
)
lowres_transform = A.Compose(
[
        # albumentations expects OpenCV interpolation flags; PIL's Image.BICUBIC
        # (== 3) would silently select cv2.INTER_AREA here
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
ToTensorV2(),
]
)
both_transforms = A.Compose(
[
A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
A.HorizontalFlip(p=0.5),
A.RandomRotate90(p=0.5),
]
)
test_transform = A.Compose(
[
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
ToTensorV2(),
]
)
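
A minimal sketch (not part of the commit) of how these pipelines chain: one random HR crop feeds both branches, so each LR/HR pair stays spatially aligned.

import numpy as np

img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # stand-in for a dataset image
crop = both_transforms(image=img)["image"]     # HIGH_RES x HIGH_RES uint8 crop (HWC)
low = lowres_transform(image=crop)["image"]    # float tensor, 3 x LOW_RES x LOW_RES
high = highres_transform(image=crop)["image"]  # float tensor, 3 x HIGH_RES x HIGH_RES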

ESRGAN/dataset.py

@@ -0,0 +1,49 @@
import os
import cv2
import config
from torch.utils.data import Dataset, DataLoader
class MyImageFolder(Dataset):
def __init__(self, root_dir):
super(MyImageFolder, self).__init__()
self.data = []
self.root_dir = root_dir
self.class_names = os.listdir(root_dir)
for index, name in enumerate(self.class_names):
files = os.listdir(os.path.join(root_dir, name))
self.data += list(zip(files, [index] * len(files)))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
img_file, label = self.data[index]
root_and_dir = os.path.join(self.root_dir, self.class_names[label])
image = cv2.imread(os.path.join(root_and_dir, img_file))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
both_transform = config.both_transforms(image=image)["image"]
low_res = config.lowres_transform(image=both_transform)["image"]
high_res = config.highres_transform(image=both_transform)["image"]
return low_res, high_res
def test():
dataset = MyImageFolder(root_dir="data/")
loader = DataLoader(dataset, batch_size=8)
for low_res, high_res in loader:
print(low_res.shape)
print(high_res.shape)
if __name__ == "__main__":
test()

ESRGAN/loss.py

@@ -0,0 +1,19 @@
import torch.nn as nn
from torchvision.models import vgg19
import config
class VGGLoss(nn.Module):
def __init__(self):
super().__init__()
self.vgg = vgg19(pretrained=True).features[:35].eval().to(config.DEVICE)
for param in self.vgg.parameters():
param.requires_grad = False
self.loss = nn.MSELoss()
def forward(self, input, target):
vgg_input_features = self.vgg(input)
vgg_target_features = self.vgg(target)
return self.loss(vgg_input_features, vgg_target_features)
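
Note that features[:35] truncates VGG19 just after conv5_4 and before its ReLU, the "before activation" feature maps the ESRGAN paper argues preserve more texture. A minimal usage sketch on dummy batches (the shapes are assumptions):

import torch

perceptual = VGGLoss()
sr = torch.randn(4, 3, 128, 128).to(config.DEVICE)  # e.g. generator output
hr = torch.randn(4, 3, 128, 128).to(config.DEVICE)
print(perceptual(sr, hr))  # scalar MSE between the two VGG feature maps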

ESRGAN/model.py

@@ -0,0 +1,154 @@
import torch
from torch import nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, use_act, **kwargs):
super().__init__()
self.cnn = nn.Conv2d(
in_channels,
out_channels,
**kwargs,
bias=True,
)
self.act = nn.LeakyReLU(0.2, inplace=True) if use_act else nn.Identity()
def forward(self, x):
return self.act(self.cnn(x))
class UpsampleBlock(nn.Module):
def __init__(self, in_c, scale_factor=2):
super().__init__()
self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
self.conv = nn.Conv2d(in_c, in_c, 3, 1, 1, bias=True)
self.act = nn.LeakyReLU(0.2, inplace=True)
def forward(self, x):
return self.act(self.conv(self.upsample(x)))
class DenseResidualBlock(nn.Module):
def __init__(self, in_channels, channels=32, residual_beta=0.2):
super().__init__()
self.residual_beta = residual_beta
self.blocks = nn.ModuleList()
for i in range(5):
self.blocks.append(
ConvBlock(
in_channels + channels * i,
channels if i <= 3 else in_channels,
kernel_size=3,
stride=1,
padding=1,
                    use_act=i <= 3,
)
)
def forward(self, x):
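        # Dense connections: block i sees x concatenated with every earlier
        # block's output; the 5th block maps back to in_channels, and its
        # output is scaled by residual_beta before the skip connection.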
new_inputs = x
for block in self.blocks:
out = block(new_inputs)
new_inputs = torch.cat([new_inputs, out], dim=1)
return self.residual_beta * out + x
class RRDB(nn.Module):
def __init__(self, in_channels, residual_beta=0.2):
super().__init__()
self.residual_beta = residual_beta
self.rrdb = nn.Sequential(*[DenseResidualBlock(in_channels) for _ in range(3)])
def forward(self, x):
return self.rrdb(x) * self.residual_beta + x
class Generator(nn.Module):
def __init__(self, in_channels=3, num_channels=64, num_blocks=23):
super().__init__()
self.initial = nn.Conv2d(
in_channels,
num_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.residuals = nn.Sequential(*[RRDB(num_channels) for _ in range(num_blocks)])
self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)
self.upsamples = nn.Sequential(
UpsampleBlock(num_channels), UpsampleBlock(num_channels),
)
self.final = nn.Sequential(
nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_channels, in_channels, 3, 1, 1, bias=True),
)
def forward(self, x):
initial = self.initial(x)
x = self.conv(self.residuals(initial)) + initial
x = self.upsamples(x)
return self.final(x)
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
super().__init__()
blocks = []
for idx, feature in enumerate(features):
blocks.append(
ConvBlock(
in_channels,
feature,
kernel_size=3,
stride=1 + idx % 2,
padding=1,
use_act=True,
),
)
in_channels = feature
self.blocks = nn.Sequential(*blocks)
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d((6, 6)),
nn.Flatten(),
nn.Linear(512 * 6 * 6, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 1),
)
def forward(self, x):
x = self.blocks(x)
return self.classifier(x)
def initialize_weights(model, scale=0.1):
    # scale down the default Kaiming init; the ESRGAN paper notes that a smaller
    # initialization makes the very deep residual architecture easier to train
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight.data)
            m.weight.data *= scale
def test():
gen = Generator()
disc = Discriminator()
low_res = 24
x = torch.randn((5, 3, low_res, low_res))
gen_out = gen(x)
disc_out = disc(gen_out)
print(gen_out.shape)
print(disc_out.shape)
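    # expected: torch.Size([5, 3, 96, 96]) (24 x 4 from two x2 upsamples) and torch.Size([5, 1])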
if __name__ == "__main__":
test()

(12 binary image files added, 8.9 KiB to 543 KiB; previews not shown)

ESRGAN/train.py

@@ -0,0 +1,154 @@
import torch
import config
from torch import nn
from torch import optim
from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator, initialize_weights
from tqdm import tqdm
from dataset import MyImageFolder
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark = True
def train_fn(
loader,
disc,
gen,
opt_gen,
opt_disc,
l1,
vgg_loss,
g_scaler,
d_scaler,
writer,
tb_step,
):
loop = tqdm(loader, leave=True)
for idx, (low_res, high_res) in enumerate(loop):
high_res = high_res.to(config.DEVICE)
low_res = low_res.to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(low_res)
critic_real = disc(high_res)
critic_fake = disc(fake.detach())
gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
)
opt_disc.zero_grad()
d_scaler.scale(loss_critic).backward()
d_scaler.step(opt_disc)
d_scaler.update()
        # Train Generator: maximize E[critic(G(z))] (WGAN objective), plus L1 and VGG terms
with torch.cuda.amp.autocast():
l1_loss = 1e-2 * l1(fake, high_res)
adversarial_loss = 5e-3 * -torch.mean(disc(fake))
loss_for_vgg = vgg_loss(fake, high_res)
gen_loss = l1_loss + loss_for_vgg + adversarial_loss
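            # weighting matches the ESRGAN paper's total loss
            # L_G = eta * L1 + L_percep + lambda * L_adv, with eta = 1e-2, lambda = 5e-3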
opt_gen.zero_grad()
g_scaler.scale(gen_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
tb_step += 1
if idx % 100 == 0 and idx > 0:
plot_examples("test_images/", gen)
loop.set_postfix(
gp=gp.item(),
critic=loss_critic.item(),
l1=l1_loss.item(),
vgg=loss_for_vgg.item(),
adversarial=adversarial_loss.item(),
)
return tb_step
def main():
dataset = MyImageFolder(root_dir="data/")
loader = DataLoader(
dataset,
batch_size=config.BATCH_SIZE,
shuffle=True,
pin_memory=True,
num_workers=config.NUM_WORKERS,
)
gen = Generator(in_channels=3).to(config.DEVICE)
disc = Discriminator(in_channels=3).to(config.DEVICE)
initialize_weights(gen)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
writer = SummaryWriter("logs")
tb_step = 0
l1 = nn.L1Loss()
gen.train()
disc.train()
vgg_loss = VGGLoss()
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN,
gen,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_DISC,
disc,
opt_disc,
config.LEARNING_RATE,
)
for epoch in range(config.NUM_EPOCHS):
tb_step = train_fn(
loader,
disc,
gen,
opt_gen,
opt_disc,
l1,
vgg_loss,
g_scaler,
d_scaler,
writer,
tb_step,
)
if config.SAVE_MODEL:
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
if __name__ == "__main__":
try_model = True
if try_model:
        # Use pretrained weights to super-resolve the images in test_images/
        # and save the results in saved/
gen = Generator(in_channels=3).to(config.DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
load_checkpoint(
config.CHECKPOINT_GEN,
gen,
opt_gen,
config.LEARNING_RATE,
)
plot_examples("test_images/", gen)
else:
# This will train from scratch
main()

ESRGAN/utils.py

@@ -0,0 +1,67 @@
import torch
import os
import config
import numpy as np
from PIL import Image
from torchvision.utils import save_image
def gradient_penalty(critic, real, fake, device):
BATCH_SIZE, C, H, W = real.shape
alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * alpha + fake.detach() * (1 - alpha)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
    # If we don't do this then the learning rate will stay at the old
    # checkpoint's value, and that will lead to many hours of debugging :(
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def plot_examples(low_res_folder, gen):
files = os.listdir(low_res_folder)
gen.eval()
for file in files:
        image = Image.open(os.path.join(low_res_folder, file))
with torch.no_grad():
upscaled_img = gen(
config.test_transform(image=np.asarray(image))["image"]
.unsqueeze(0)
.to(config.DEVICE)
)
save_image(upscaled_img, f"saved/{file}")
gen.train()
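
For reference, gradient_penalty above computes the WGAN-GP regularizer of Gulrajani et al., weighted by LAMBDA_GP = 10 in the critic loss:

$\mathcal{L}_{GP} = \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right], \qquad \hat{x} = \alpha\, x_{\mathrm{real}} + (1 - \alpha)\, x_{\mathrm{fake}}, \quad \alpha \sim U[0, 1]$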

(binary image file added, 1013 KiB; preview not shown)

SRGAN/config.py

@@ -0,0 +1,47 @@
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth.tar"
CHECKPOINT_DISC = "disc.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
BATCH_SIZE = 16
NUM_WORKERS = 4
HIGH_RES = 96
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3
highres_transform = A.Compose(
[
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
ToTensorV2(),
]
)
lowres_transform = A.Compose(
[
        # use the OpenCV flag; PIL's Image.BICUBIC maps to the wrong OpenCV mode
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
ToTensorV2(),
]
)
both_transforms = A.Compose(
[
A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
A.HorizontalFlip(p=0.5),
A.RandomRotate90(p=0.5),
]
)
test_transform = A.Compose(
[
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
ToTensorV2(),
]
)

SRGAN/dataset.py

@@ -0,0 +1,43 @@
import os
import numpy as np
import config
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class MyImageFolder(Dataset):
def __init__(self, root_dir):
super(MyImageFolder, self).__init__()
self.data = []
self.root_dir = root_dir
self.class_names = os.listdir(root_dir)
for index, name in enumerate(self.class_names):
files = os.listdir(os.path.join(root_dir, name))
self.data += list(zip(files, [index] * len(files)))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
img_file, label = self.data[index]
root_and_dir = os.path.join(self.root_dir, self.class_names[label])
image = np.array(Image.open(os.path.join(root_and_dir, img_file)))
image = config.both_transforms(image=image)["image"]
high_res = config.highres_transform(image=image)["image"]
low_res = config.lowres_transform(image=image)["image"]
return low_res, high_res
def test():
dataset = MyImageFolder(root_dir="new_data/")
loader = DataLoader(dataset, batch_size=1, num_workers=8)
for low_res, high_res in loader:
print(low_res.shape)
print(high_res.shape)
if __name__ == "__main__":
test()

SRGAN/loss.py

@@ -0,0 +1,21 @@
import torch.nn as nn
from torchvision.models import vgg19
import config
# phi_{5,4}: VGG19 features from the 4th conv in block 5, taken after its
# activation and before the 5th max-pool (hence features[:36])
class VGGLoss(nn.Module):
def __init__(self):
super().__init__()
self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE)
self.loss = nn.MSELoss()
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, input, target):
vgg_input_features = self.vgg(input)
vgg_target_features = self.vgg(target)
return self.loss(vgg_input_features, vgg_target_features)

SRGAN/model.py

@@ -0,0 +1,128 @@
import torch
from torch import nn
class ConvBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
discriminator=False,
use_act=True,
use_bn=True,
**kwargs,
):
super().__init__()
self.use_act = use_act
self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
self.act = (
nn.LeakyReLU(0.2, inplace=True)
if discriminator
else nn.PReLU(num_parameters=out_channels)
)
def forward(self, x):
return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
class UpsampleBlock(nn.Module):
def __init__(self, in_c, scale_factor):
super().__init__()
self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
self.ps = nn.PixelShuffle(scale_factor) # in_c * 4, H, W --> in_c, H*2, W*2
self.act = nn.PReLU(num_parameters=in_c)
def forward(self, x):
return self.act(self.ps(self.conv(x)))
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.block1 = ConvBlock(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1
)
self.block2 = ConvBlock(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1,
use_act=False,
)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
return out + x
class Generator(nn.Module):
def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
super().__init__()
self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)
def forward(self, x):
initial = self.initial(x)
x = self.residuals(initial)
x = self.convblock(x) + initial
x = self.upsamples(x)
return torch.tanh(self.final(x))
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
super().__init__()
blocks = []
for idx, feature in enumerate(features):
blocks.append(
ConvBlock(
in_channels,
feature,
kernel_size=3,
stride=1 + idx % 2,
padding=1,
discriminator=True,
use_act=True,
use_bn=False if idx == 0 else True,
)
)
in_channels = feature
self.blocks = nn.Sequential(*blocks)
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d((6, 6)),
nn.Flatten(),
nn.Linear(512*6*6, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 1),
)
def forward(self, x):
x = self.blocks(x)
return self.classifier(x)
def test():
low_resolution = 24 # 96x96 -> 24x24
with torch.cuda.amp.autocast():
x = torch.randn((5, 3, low_resolution, low_resolution))
gen = Generator()
gen_out = gen(x)
disc = Discriminator()
disc_out = disc(gen_out)
print(gen_out.shape)
print(disc_out.shape)
if __name__ == "__main__":
test()
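
A quick sanity sketch (shapes are assumptions, not from the commit) of the sub-pixel path: the conv multiplies channels by scale_factor^2 and PixelShuffle trades those channels for spatial resolution.

x = torch.randn(1, 64, 24, 24)
up = UpsampleBlock(64, scale_factor=2)
print(up(x).shape)  # torch.Size([1, 64, 48, 48])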

SRGAN/train.py

@@ -0,0 +1,88 @@
import torch
import config
from torch import nn
from torch import optim
from utils import load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator
from tqdm import tqdm
from dataset import MyImageFolder
torch.backends.cudnn.benchmark = True
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
loop = tqdm(loader, leave=True)
for idx, (low_res, high_res) in enumerate(loop):
high_res = high_res.to(config.DEVICE)
low_res = low_res.to(config.DEVICE)
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
fake = gen(low_res)
disc_real = disc(high_res)
disc_fake = disc(fake.detach())
disc_loss_real = bce(
disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
)
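        # one-sided label smoothing: real targets are drawn from [0.9, 1.0]
        # instead of hard 1s, which softens the discriminator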
disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
loss_disc = disc_loss_fake + disc_loss_real
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
disc_fake = disc(fake)
#l2_loss = mse(fake, high_res)
adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
gen_loss = loss_for_vgg + adversarial_loss
opt_gen.zero_grad()
gen_loss.backward()
opt_gen.step()
if idx % 200 == 0:
plot_examples("test_images/", gen)
def main():
dataset = MyImageFolder(root_dir="new_data/")
loader = DataLoader(
dataset,
batch_size=config.BATCH_SIZE,
shuffle=True,
pin_memory=True,
num_workers=config.NUM_WORKERS,
)
gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
vgg_loss = VGGLoss()
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN,
gen,
opt_gen,
config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
)
for epoch in range(config.NUM_EPOCHS):
train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
if config.SAVE_MODEL:
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
if __name__ == "__main__":
main()

SRGAN/utils.py

@@ -0,0 +1,66 @@
import torch
import os
import config
import numpy as np
from PIL import Image
from torchvision.utils import save_image
def gradient_penalty(critic, real, fake, device):
BATCH_SIZE, C, H, W = real.shape
alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * alpha + fake.detach() * (1 - alpha)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
    # If we don't do this then the learning rate will stay at the old
    # checkpoint's value, and that will lead to many hours of debugging :(
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def plot_examples(low_res_folder, gen):
files = os.listdir(low_res_folder)
gen.eval()
for file in files:
        image = Image.open(os.path.join(low_res_folder, file))
with torch.no_grad():
upscaled_img = gen(
config.test_transform(image=np.asarray(image))["image"]
.unsqueeze(0)
.to(config.DEVICE)
)
save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
gen.train()

StyleGAN/config.py

@@ -0,0 +1,25 @@
import torch
START_TRAIN_AT_IMG_SIZE = 32
DATASET = 'FFHQ_32'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LOAD_MODEL = False
SAVE_MODEL = True
LEARNING_RATE = 1e-3
BATCH_SIZES = [32, 32, 32, 32, 32, 16, 8, 4, 2]
CHANNELS_IMG = 3
Z_DIM = 512
W_DIM = 512
IN_CHANNELS = 512
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [50] * 100
FIXED_NOISE = torch.randn((8, Z_DIM)).to(DEVICE)
NUM_WORKERS = 6
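
A small sketch (not part of the commit) of how these constants drive the progressive-growing schedule, where step = log2(image_size / 4):

from math import log2

for size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    step = int(log2(size / 4))
    print(size, BATCH_SIZES[step], PROGRESSIVE_EPOCHS[step])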

StyleGAN: FFHQ resize script

@@ -0,0 +1,9 @@
import os
from PIL import Image
from tqdm import tqdm

root_dir = "FFHQ/images1024x1024"
out_dir = "FFHQ_resized"
os.makedirs(out_dir, exist_ok=True)
for file in tqdm(os.listdir(root_dir)):
    img = Image.open(os.path.join(root_dir, file)).resize((128, 128))
    img.save(os.path.join(out_dir, file))

StyleGAN/model.py

@@ -0,0 +1,293 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
class PixelNorm(nn.Module):
def __init__(self):
super(PixelNorm, self).__init__()
self.epsilon = 1e-8
def forward(self, x):
return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
class MappingNetwork(nn.Module):
def __init__(self, z_dim, w_dim):
super().__init__()
self.mapping = nn.Sequential(
PixelNorm(),
WSLinear(z_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
)
def forward(self, x):
return self.mapping(x)
class InjectNoise(nn.Module):
def __init__(self, channels):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
def forward(self, x):
noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
return x + self.weight * noise
class AdaIN(nn.Module):
def __init__(self, channels, w_dim):
super().__init__()
self.instance_norm = nn.InstanceNorm2d(channels)
self.style_scale = WSLinear(w_dim, channels)
self.style_bias = WSLinear(w_dim, channels)
def forward(self, x, w):
x = self.instance_norm(x)
style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
return style_scale * x + style_bias
class WSConv2d(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2,
):
super(WSConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
self.bias = self.conv.bias
self.conv.bias = None
# initialize conv layer
nn.init.normal_(self.conv.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
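# Equalized learning rate (from ProGAN, reused here): weights are sampled from
# N(0, 1) and rescaled at runtime by He's constant sqrt(gain / fan_in), with
# fan_in = in_channels * kernel_size^2, so every layer trains at a similar
# effective step size under Adam.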
class WSLinear(nn.Module):
def __init__(
self, in_features, out_features, gain=2,
):
super(WSLinear, self).__init__()
self.linear = nn.Linear(in_features, out_features)
self.scale = (gain / in_features)**0.5
self.bias = self.linear.bias
self.linear.bias = None
# initialize linear layer
nn.init.normal_(self.linear.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.linear(x * self.scale) + self.bias
class GenBlock(nn.Module):
def __init__(self, in_channels, out_channels, w_dim):
super(GenBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.inject_noise1 = InjectNoise(out_channels)
self.inject_noise2 = InjectNoise(out_channels)
self.adain1 = AdaIN(out_channels, w_dim)
self.adain2 = AdaIN(out_channels, w_dim)
def forward(self, x, w):
x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
return x
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2)
def forward(self, x):
x = self.leaky(self.conv1(x))
x = self.leaky(self.conv2(x))
return x
class Generator(nn.Module):
def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
super(Generator, self).__init__()
self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
self.map = MappingNetwork(z_dim, w_dim)
self.initial_adain1 = AdaIN(in_channels, w_dim)
self.initial_adain2 = AdaIN(in_channels, w_dim)
self.initial_noise1 = InjectNoise(in_channels)
self.initial_noise2 = InjectNoise(in_channels)
self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.initial_rgb = WSConv2d(
in_channels, img_channels, kernel_size=1, stride=1, padding=0
)
self.prog_blocks, self.rgb_layers = (
nn.ModuleList([]),
nn.ModuleList([self.initial_rgb]),
)
for i in range(len(factors) - 1): # -1 to prevent index error because of factors[i+1]
conv_in_c = int(in_channels * factors[i])
conv_out_c = int(in_channels * factors[i + 1])
self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
self.rgb_layers.append(
WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
)
def fade_in(self, alpha, upscaled, generated):
        # alpha should be a scalar in [0, 1], and upscaled.shape == generated.shape
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
def forward(self, noise, alpha, steps):
w = self.map(noise)
x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
x = self.initial_conv(x)
out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)
        if steps == 0:
            # return the post-AdaIN activation; rgb(x) here would skip the
            # second noise/AdaIN branch and leave `out` unused
            return self.initial_rgb(out)
for step in range(steps):
upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
out = self.prog_blocks[step](upscaled, w)
# The number of channels in upscale will stay the same, while
# out which has moved through prog_blocks might change. To ensure
# we can convert both to rgb we use different rgb_layers
# (steps-1) and steps for upscaled, out respectively
final_upscaled = self.rgb_layers[steps - 1](upscaled)
final_out = self.rgb_layers[steps](out)
return self.fade_in(alpha, final_upscaled, final_out)
class Discriminator(nn.Module):
def __init__(self, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)
        # here we work backwards through factors because the discriminator
        # should mirror the generator: the first prog_block and rgb layer we
        # append handle input size 1024x1024, then 512, 256, etc.
for i in range(len(factors) - 1, 0, -1):
conv_in = int(in_channels * factors[i])
conv_out = int(in_channels * factors[i - 1])
self.prog_blocks.append(ConvBlock(conv_in, conv_out))
self.rgb_layers.append(
WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
)
# perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
# did this to "mirror" the generator initial_rgb
self.initial_rgb = WSConv2d(
img_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.rgb_layers.append(self.initial_rgb)
self.avg_pool = nn.AvgPool2d(
kernel_size=2, stride=2
) # down sampling using avg pool
# this is the block for 4x4 input size
self.final_block = nn.Sequential(
# +1 to in_channels because we concatenate from MiniBatch std
WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
nn.LeakyReLU(0.2),
WSConv2d(
in_channels, 1, kernel_size=1, padding=0, stride=1
), # we use this instead of linear layer
)
def fade_in(self, alpha, downscaled, out):
"""Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be a scalar in [0, 1]; downscaled.shape == out.shape
return alpha * out + (1 - alpha) * downscaled
def minibatch_std(self, x):
batch_statistics = (
torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
)
# we take the std for each example (across all channels, and pixels) then we repeat it
# for a single channel and concatenate it with the image. In this way the discriminator
# will get information about the variation in the batch/image
return torch.cat([x, batch_statistics], dim=1)
def forward(self, x, alpha, steps):
# where we should start in the list of prog_blocks, maybe a bit confusing but
# the last is for the 4x4. So example let's say steps=1, then we should start
# at the second to last because input_size will be 8x8. If steps==0 we just
# use the final block
cur_step = len(self.prog_blocks) - steps
        # convert from rgb as the initial step; this depends on the image size
        # (each resolution has its own rgb layer)
out = self.leaky(self.rgb_layers[cur_step](x))
if steps == 0: # i.e, image is 4x4
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
# because prog_blocks might change the channels, for down scale we use rgb_layer
# from previous/smaller size which in our case correlates to +1 in the indexing
downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
out = self.avg_pool(self.prog_blocks[cur_step](out))
# the fade_in is done first between the downscaled and the input
# this is opposite from the generator
out = self.fade_in(alpha, downscaled, out)
for step in range(cur_step + 1, len(self.prog_blocks)):
out = self.prog_blocks[step](out)
out = self.avg_pool(out)
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
if __name__ == "__main__":
    Z_DIM = 512
    W_DIM = 512
    IN_CHANNELS = 512
    gen = Generator(Z_DIM, W_DIM, IN_CHANNELS, img_channels=3).to("cuda")
    disc = Discriminator(IN_CHANNELS, img_channels=3).to("cuda")
    print(sum(param.numel() for param in gen.parameters()))  # total generator parameters
    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        x = torch.randn((2, Z_DIM)).to("cuda")
        z = gen(x, 0.5, steps=num_steps)
        assert z.shape == (2, 3, img_size, img_size)
        out = disc(z, alpha=0.5, steps=num_steps)
        assert out.shape == (2, 1)
        print(f"Success! At img size: {img_size}")

StyleGAN: multiprocess FFHQ resize script

@@ -0,0 +1,30 @@
from PIL import Image
from tqdm import tqdm
import os
from multiprocessing import Pool
root_dir = "FFHQ/images1024x1024/"
files = os.listdir(root_dir)
def resize(file, size, folder_to_save):
image = Image.open(root_dir + file).resize((size, size), Image.LANCZOS)
image.save(folder_to_save+file, quality=100)
if __name__ == "__main__":
for img_size in [4, 8, 512, 1024]:
folder_name = "FFHQ_"+str(img_size)+"/images/"
if not os.path.isdir(folder_name):
os.makedirs(folder_name)
data = [(file, img_size, folder_name) for file in files]
pool = Pool()
pool.starmap(resize, data)

StyleGAN/README.md

@@ -0,0 +1,2 @@
This implementation doesn't work yet; I need to debug and see where I've made a mistake. It seems to do something
that makes sense, but it's nowhere near the level of performance reported in the original paper.

StyleGAN/train.py

@@ -0,0 +1,200 @@
""" Training of ProGAN using WGAN-GP loss"""
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import (
gradient_penalty,
plot_to_tensorboard,
save_checkpoint,
load_checkpoint,
EMA,
)
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import config
torch.backends.cudnn.benchmark = True
def get_loader(image_size):
transform = transforms.Compose(
[
#transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(
[0.5 for _ in range(config.CHANNELS_IMG)],
[0.5 for _ in range(config.CHANNELS_IMG)],
),
]
)
batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=config.NUM_WORKERS,
pin_memory=True,
)
return loader, dataset
def train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen,
tensorboard_step,
writer,
scaler_gen,
scaler_critic,
ema,
):
loop = tqdm(loader, leave=True)
gen2 = Generator(
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
for batch_idx, (real, _) in enumerate(loop):
real = real.to(config.DEVICE)
cur_batch_size = real.shape[0]
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
# which is equivalent to minimizing the negative of the expression
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
+ (0.001 * torch.mean(critic_real ** 2))
)
opt_critic.zero_grad()
scaler_critic.scale(loss_critic).backward()
scaler_critic.step(opt_critic)
scaler_critic.update()
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
with torch.cuda.amp.autocast():
gen_fake = critic(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
opt_gen.zero_grad()
scaler_gen.scale(loss_gen).backward()
scaler_gen.step(opt_gen)
scaler_gen.update()
# Update alpha and ensure less than 1
alpha += cur_batch_size / (
(config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
)
alpha = min(alpha, 1)
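        # alpha reaches 1 after roughly half the epochs at this resolution (the
        # 0.5 factor): the first half fades the new block in, the rest stabilizes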
if batch_idx % 100 == 0:
ema(gen)
with torch.no_grad():
ema.copy_weights_to(gen2)
fixed_fakes = gen2(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
plot_to_tensorboard(
writer,
loss_critic.item(),
loss_gen.item(),
real.detach(),
fixed_fakes.detach(),
tensorboard_step,
)
tensorboard_step += 1
loop.set_postfix(
gp=gp.item(),
loss_critic=loss_critic.item(),
)
return tensorboard_step, alpha
def main():
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
# but really who cares..
gen = Generator(
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
ema = EMA(gamma=0.999, save_frequency=2000)
# initialize optimizers and scalers for FP16 training
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
{"params": gen.map.parameters(), "lr": 1e-5}], lr=config.LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
    writer = SummaryWriter("logs/gan")
if config.LOAD_MODEL:
load_checkpoint(
config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
)
load_checkpoint(
config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
)
gen.train()
critic.train()
tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
print(f"Current image size: {4 * 2 ** step}")
for epoch in range(num_epochs):
print(f"Epoch [{epoch+1}/{num_epochs}]")
tensorboard_step, alpha = train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen,
tensorboard_step,
writer,
scaler_gen,
scaler_critic,
ema,
)
if config.SAVE_MODEL:
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
step += 1 # progress to the next img size
if __name__ == "__main__":
main()

StyleGAN/utils.py

@@ -0,0 +1,126 @@
import torch
import random
import numpy as np
import os
import torchvision
import warnings
# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
with torch.no_grad():
        # take out (up to) 8 examples for the image grids
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images, alpha, train_step)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location="cuda")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
    # If we don't do this then the learning rate will stay at the old
    # checkpoint's value, and that will lead to many hours of debugging :(
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def seed_everything(seed=42):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EMA:
# Found this useful (thanks alexis-jacq):
# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3
def __init__(self, gamma=0.99, save=True, save_frequency=100, save_filename="ema_weights.pth"):
"""
Initialize the weight to which we will do the
exponential moving average and the dictionary
where we store the model parameters
"""
self.gamma = gamma
self.registered = {}
self.save_filename = save_filename
self.save_frequency = save_frequency
self.count = 0
if save_filename in os.listdir("."):
self.registered = torch.load(self.save_filename)
if not save:
warnings.warn("Note that the exponential moving average weights will not be saved to a .pth file!")
def register_weights(self, model):
"""
Registers the weights of the model which will
later be used when we take the moving average
"""
for name, param in model.named_parameters():
if param.requires_grad:
self.registered[name] = param.clone().detach()
    def __call__(self, model):
        self.count += 1
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name not in self.registered:
                    new_weight = param.clone().detach()
                else:
                    # EMA: keep gamma of the running average, fold in (1 - gamma) of the current weights
                    new_weight = (
                        self.gamma * self.registered[name]
                        + (1 - self.gamma) * param.detach()
                    )
                self.registered[name] = new_weight
        if self.count % self.save_frequency == 0:
            self.save_ema_weights()
def copy_weights_to(self, model):
for name, param in model.named_parameters():
if param.requires_grad:
param.data = self.registered[name]
def save_ema_weights(self):
torch.save(self.registered, self.save_filename)
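
A hedged usage sketch for the EMA helper above (the model names are stand-ins, not from the commit):

import torch.nn as nn

model = nn.Linear(4, 4)   # stand-in for the generator being trained
shadow = nn.Linear(4, 4)  # separate copy that receives the averaged weights
ema = EMA(gamma=0.999, save=False)
ema.register_weights(model)  # snapshot the initial weights
for _ in range(10):
    # ... optimizer.step() on model would go here ...
    ema(model)  # fold the current weights into the running average
ema.copy_weights_to(shadow)  # sample/evaluate with the smoothed weights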