Use the training set's max length when --context_length is "longest_training_example"

This commit is contained in:
rasbt
2024-04-29 21:50:07 -05:00
parent a5b353667d
commit 354bb35726
4 changed files with 35 additions and 23 deletions

View File

@@ -123,6 +123,9 @@ def instantiate_model(choose_model, load_weights):
}
BASE_CONFIG.update(model_configs[choose_model])
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
if load_weights:
@@ -354,17 +357,20 @@ if __name__ == "__main__":
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = None
if args.context_length == "model_context_length":
max_length = model.pos_emb.weight.shape[0]
elif args.context_length == "longest_training_example":
max_length = None
train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
max_length = train_dataset.max_length
else:
try:
max_length = int(args.context_length)
except ValueError:
raise ValueError("Invalid --context_length argument")
train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
if train_dataset is None:
train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)