updated progan

This commit is contained in:
Aladdin Persson
2021-03-24 13:01:45 +01:00
parent 59b1de7bfe
commit 74597aa8fd
5 changed files with 15 additions and 26 deletions

View File

@@ -60,24 +60,19 @@ def train_fn(
scaler_critic,
):
loop = tqdm(loader, leave=True)
# critic_losses = []
reals = 0
fakes = 0
for batch_idx, (real, _) in enumerate(loop):
real = real.to(config.DEVICE)
cur_batch_size = real.shape[0]
# Train Critic: maximize E[critic(real)] - E[critic(fake)], which is equivalent
# to minimizing its negative, -E[critic(real)] + E[critic(fake)]
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
with torch.cuda.amp.autocast():
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
reals += critic_real.mean().item()
fakes += critic_fake.mean().item()
gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ config.LAMBDA_GP * gp
@@ -119,8 +114,6 @@ def train_fn(
tensorboard_step += 1
loop.set_postfix(
reals=reals / (batch_idx + 1),
fakes=fakes / (batch_idx + 1),
gp=gp.item(),
loss_critic=loss_critic.item(),
)
@@ -131,11 +124,12 @@ def train_fn(
def main():
# Initialize the generator and the critic. Note: following the WGAN paper, the
# discriminator is properly called a critic, since its output is no longer
# restricted to [0, 1]; the class here is nevertheless still named Discriminator.
gen = Generator(
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
# initialize optimizers and scalers for FP16 training
@@ -147,7 +141,7 @@ def main():
scaler_gen = torch.cuda.amp.GradScaler()
# for tensorboard plotting
writer = SummaryWriter(f"logs/gan")
writer = SummaryWriter(f"logs/gan1")
if config.LOAD_MODEL:
load_checkpoint(
@@ -163,10 +157,6 @@ def main():
tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
generate_examples(gen, step)
import sys
sys.exit()
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
@@ -197,4 +187,4 @@ def main():
if __name__ == "__main__":
main()
main()