# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
import torch
import config
from torchvision.utils import save_image


def save_some_examples(gen, val_loader, epoch, folder):
    # Run the generator on one validation batch and save the results as images
    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE), y.to(config.DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # undo the [-1, 1] normalization back to [0, 1]
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()
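
# A minimal sketch of how save_some_examples might be called from a training
# loop; `gen`, `val_loader`, `num_epochs`, and the "evaluation" folder are
# illustrative names, not definitions from this file.
#
#     for epoch in range(num_epochs):
#         ...  # train the generator/discriminator for one epoch
#         save_some_examples(gen, val_loader, epoch, folder="evaluation")
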
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    # Save both the model weights and the optimizer state so training can resume
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)
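
# A sketch of how save_checkpoint might be invoked periodically during training;
# config.SAVE_MODEL and config.CHECKPOINT_GEN are assumed attributes used here
# only for illustration.
#
#     if config.SAVE_MODEL and epoch % 5 == 0:
#         save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
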
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't reset it, the optimizer keeps the learning rate stored in the
    # old checkpoint, which can lead to many hours of debugging.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
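

# A minimal, self-contained round-trip check of the two checkpoint helpers
# above. The tiny linear model, the Adam optimizer, and the temporary filename
# are assumptions for illustration only; they are not part of the original file.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2).to(config.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=2e-4)

    save_checkpoint(model, optimizer, filename="tmp_checkpoint.pth.tar")

    # Restore the weights and force the learning rate back to 1e-4, overriding
    # whatever value was stored with the optimizer state.
    load_checkpoint("tmp_checkpoint.pth.tar", model, optimizer, lr=1e-4)
    print(optimizer.param_groups[0]["lr"])  # expected: 0.0001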