# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# synced 2026-02-21 11:18:01 +00:00
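"""Utility functions: WGAN-GP gradient penalty, checkpoint saving/loading, and
a helper that runs a super-resolution generator over a folder of low-res test
images."""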
import torch
import os
import config
import numpy as np
from PIL import Image
from torchvision.utils import save_image


def gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty: penalizes the critic when the norm of its
    gradient on random real/fake interpolations deviates from 1."""
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake.detach() * (1 - alpha)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
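

# Usage sketch (assumed names, not part of the original file): in a WGAN-GP
# training step the penalty is typically added to the critic loss, e.g.
#
#   gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
#   loss_critic = (
#       -(torch.mean(critic(real)) - torch.mean(critic(fake))) + LAMBDA_GP * gp
#   )
#
# where LAMBDA_GP (commonly 10) is a hypothetical hyperparameter, not defined
# in this file.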


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Save the model and optimizer state dicts to a single checkpoint file."""
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint file and reset the
    optimizer's learning rate to `lr`."""
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't reset it, the optimizer keeps the learning rate stored in the
    # old checkpoint, which can lead to many hours of debugging.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
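

# Usage sketch (assumed names, not part of the original file): checkpoints are
# typically written at the end of an epoch and restored before resuming, e.g.
#
#   save_checkpoint(gen, opt_gen, filename="gen.pth.tar")
#   load_checkpoint("gen.pth.tar", gen, opt_gen, lr=1e-4)
#
# where gen, opt_gen, and the learning rate value are placeholders.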


def plot_examples(low_res_folder, gen):
    """Run the generator on every image in `low_res_folder` and save the
    upscaled results to the `saved/` directory."""
    files = os.listdir(low_res_folder)

    gen.eval()
    for file in files:
        image = Image.open(os.path.join(low_res_folder, file))
        with torch.no_grad():
            upscaled_img = gen(
                config.test_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config.DEVICE)
            )
        save_image(upscaled_img, f"saved/{file}")
    gen.train()
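

# Usage sketch (assumed paths, not part of the original file):
#
#   plot_examples("test_images/", gen)
#
# config.test_transform appears to be an albumentations-style transform (it is
# called with image=... and indexed with ["image"]), config.DEVICE selects the
# device, and a "saved/" directory must already exist for the outputs.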