mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
updated progan
This commit is contained in:
@@ -60,24 +60,19 @@ def train_fn(
|
||||
scaler_critic,
|
||||
):
|
||||
loop = tqdm(loader, leave=True)
|
||||
# critic_losses = []
|
||||
reals = 0
|
||||
fakes = 0
|
||||
for batch_idx, (real, _) in enumerate(loop):
|
||||
real = real.to(config.DEVICE)
|
||||
cur_batch_size = real.shape[0]
|
||||
|
||||
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
|
||||
# which is equivalent to minimizing the negative of the expression
|
||||
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
|
||||
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
fake = gen(noise, alpha, step)
|
||||
critic_real = critic(real, alpha, step)
|
||||
critic_fake = critic(fake.detach(), alpha, step)
|
||||
reals += critic_real.mean().item()
|
||||
fakes += critic_fake.mean().item()
|
||||
gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
|
||||
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
|
||||
loss_critic = (
|
||||
-(torch.mean(critic_real) - torch.mean(critic_fake))
|
||||
+ config.LAMBDA_GP * gp
|
||||
@@ -119,8 +114,6 @@ def train_fn(
|
||||
tensorboard_step += 1
|
||||
|
||||
loop.set_postfix(
|
||||
reals=reals / (batch_idx + 1),
|
||||
fakes=fakes / (batch_idx + 1),
|
||||
gp=gp.item(),
|
||||
loss_critic=loss_critic.item(),
|
||||
)
|
||||
@@ -131,11 +124,12 @@ def train_fn(
|
||||
def main():
|
||||
# initialize gen and disc, note: discriminator should be called critic,
|
||||
# according to WGAN paper (since it no longer outputs between [0, 1])
|
||||
# but really who cares..
|
||||
gen = Generator(
|
||||
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
).to(config.DEVICE)
|
||||
critic = Discriminator(
|
||||
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
|
||||
).to(config.DEVICE)
|
||||
|
||||
# initialize optimizers and scalers for FP16 training
|
||||
@@ -147,7 +141,7 @@ def main():
|
||||
scaler_gen = torch.cuda.amp.GradScaler()
|
||||
|
||||
# for tensorboard plotting
|
||||
writer = SummaryWriter(f"logs/gan")
|
||||
writer = SummaryWriter(f"logs/gan1")
|
||||
|
||||
if config.LOAD_MODEL:
|
||||
load_checkpoint(
|
||||
@@ -163,10 +157,6 @@ def main():
|
||||
tensorboard_step = 0
|
||||
# start at step that corresponds to img size that we set in config
|
||||
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
|
||||
|
||||
generate_examples(gen, step)
|
||||
import sys
|
||||
sys.exit()
|
||||
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
|
||||
alpha = 1e-5 # start with very low alpha
|
||||
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
|
||||
@@ -197,4 +187,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user