mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
cyclegan
This commit is contained in:
35
ML/Pytorch/GANs/CycleGAN/utils.py
Normal file
35
ML/Pytorch/GANs/CycleGAN/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import random, torch, os, numpy as np
|
||||
import torch.nn as nn
|
||||
import config
|
||||
import copy
|
||||
|
||||
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
checkpoint = {
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_file, model, optimizer, lr):
|
||||
print("=> Loading checkpoint")
|
||||
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
|
||||
# If we don't do this then it will just have learning rate of old checkpoint
|
||||
# and it will lead to many hours of debugging \:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
def seed_everything(seed=42):
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
Reference in New Issue
Block a user