Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2026-04-10 12:33:42 +00:00).
Commit: rename hparams to settings.
This commit is contained in:
@@ -124,7 +124,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
     # plt.show()


-def main(gpt_config, hparams):
+def main(gpt_config, settings):

     torch.manual_seed(123)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -152,7 +152,7 @@ def main(gpt_config, hparams):
     model = GPTModel(gpt_config)
     model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
     optimizer = torch.optim.AdamW(
-        model.parameters(), lr=hparams["learning_rate"], weight_decay=hparams["weight_decay"]
+        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"]
     )

     ##############################
@@ -165,7 +165,7 @@ def main(gpt_config, hparams):

     train_loader = create_dataloader_v1(
         text_data[:split_idx],
-        batch_size=hparams["batch_size"],
+        batch_size=settings["batch_size"],
         max_length=gpt_config["context_length"],
         stride=gpt_config["context_length"],
         drop_last=True,
@@ -174,7 +174,7 @@ def main(gpt_config, hparams):

     val_loader = create_dataloader_v1(
         text_data[split_idx:],
-        batch_size=hparams["batch_size"],
+        batch_size=settings["batch_size"],
         max_length=gpt_config["context_length"],
         stride=gpt_config["context_length"],
         drop_last=False,
@@ -187,7 +187,7 @@ def main(gpt_config, hparams):

     train_losses, val_losses, tokens_seen = train_model_simple(
         model, train_loader, val_loader, optimizer, device,
-        num_epochs=hparams["num_epochs"], eval_freq=5, eval_iter=1,
+        num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
         start_context="Every effort moves you",
     )

@@ -206,7 +206,7 @@ if __name__ == "__main__":
         "qkv_bias": False  # Query-key-value bias
     }

-    OTHER_HPARAMS = {
+    OTHER_SETTINGS = {
         "learning_rate": 5e-4,
         "num_epochs": 10,
         "batch_size": 2,
@@ -217,14 +217,14 @@ if __name__ == "__main__":
     # Initiate training
     ###########################

-    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_HPARAMS)
+    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

     ###########################
     # After training
     ###########################

     # Plot results
-    epochs_tensor = torch.linspace(0, OTHER_HPARAMS["num_epochs"], len(train_losses))
+    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
     plt.savefig("loss.pdf")

Reference in New Issue
Block a user