diff --git a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py index 8a738d5..d1763da 100644 --- a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py +++ b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py @@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs, print(f"Ep {epoch+1} (Step {global_step}): " f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") - # Generate text passage - if index % print_sample_iter == 0: - generate_and_print_sample( - model, train_loader.dataset.tokenizer, device, start_context - ) + # Generate text passage + if global_step % print_sample_iter == 0: + generate_and_print_sample( + model, train_loader.dataset.tokenizer, device, start_context + ) if global_step % save_ckpt_freq: file_name = output_dir / f"model_pg_{global_step}.pth" @@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs, torch.save(model.state_dict(), file_name) print(f"Saved {file_name}") - return train_losses, val_losses, tokens_seen + return train_losses, val_losses, track_tokens_seen if __name__ == "__main__": @@ -150,7 +150,7 @@ if __name__ == "__main__": help='Directory where the model checkpoints will be saved') parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train the model') - parser.add_argument('--print_sample_iter', type=int, default=500, + parser.add_argument('--print_sample_iter', type=int, default=1000, help='Iterations between printing sample outputs') parser.add_argument('--eval_freq', type=int, default=100, help='Frequency of evaluations during training') @@ -205,7 +205,9 @@ if __name__ == "__main__": start_context="Every effort moves you", ) - epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses)) + epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses)) + + print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir) plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir) torch.save(model.state_dict(), output_dir / "model_pg_final.pth") diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py index 0cd8d02..4641ba4 100644 --- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py +++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py @@ -274,8 +274,9 @@ def generate_and_print_sample(model, tokenizer, device, start_context): context_size = model.pos_emb.weight.shape[0] encoded = text_to_token_ids(start_context, tokenizer).to(device) with torch.no_grad(): - token_ids = generate_text_simple(model=model, idx=encoded, - max_new_tokens=50, context_size=context_size) + token_ids = generate_text_simple( + model=model, idx=encoded, + max_new_tokens=50, context_size=context_size) decoded_text = token_ids_to_text(token_ids, tokenizer) print(decoded_text.replace("\n", " ")) # Compact print format model.train()