use training set len

rasbt
2024-04-29 21:50:07 -05:00
parent 97ed38116a
commit 0ac19a1e50
4 changed files with 35 additions and 23 deletions


@@ -68,7 +68,7 @@ def partition_and_save(df, sizes=(35000, 5000, 10000)):
     # Save to CSV files
     train.to_csv("train.csv", index=False)
-    val.to_csv("val.csv", index=False)
+    val.to_csv("validation.csv", index=False)
     test.to_csv("test.csv", index=False)
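For context, a minimal sketch of what partition_and_save presumably does around this hunk, reconstructed from the signature in the hunk header. The shuffling step and the DataFrame column layout are assumptions, not the repo's actual code; only the three to_csv calls are taken from the diff:

import pandas as pd

def partition_and_save(df, sizes=(35000, 5000, 10000)):
    # Assumed: shuffle before slicing into train/validation/test partitions
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    train = df.iloc[:sizes[0]]
    val = df.iloc[sizes[0]:sizes[0] + sizes[1]]
    test = df.iloc[sizes[0] + sizes[1]:sizes[0] + sizes[1] + sizes[2]]

    # Save to CSV files ("validation.csv" matches what the training script loads)
    train.to_csv("train.csv", index=False)
    val.to_csv("validation.csv", index=False)
    test.to_csv("test.csv", index=False)

The rename from "val.csv" to "validation.csv" is paired with the matching change in the training script below, so the two files stay in sync.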


@@ -67,6 +67,9 @@ def instantiate_model(choose_model, load_weights):
     }
     BASE_CONFIG.update(model_configs[choose_model])
 
+    if not load_weights:
+        torch.manual_seed(123)
+
     model = GPTModel(BASE_CONFIG)
     if load_weights:
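A note on the new branch: seeding the RNG right before GPTModel(BASE_CONFIG) makes the random weight initialization reproducible across runs, which only matters when pretrained weights are not loaded afterwards. A minimal, self-contained sketch of the effect, with torch.nn.Linear standing in for the full GPTModel:

import torch

torch.manual_seed(123)
a = torch.nn.Linear(4, 4)   # stand-in for GPTModel(BASE_CONFIG)

torch.manual_seed(123)
b = torch.nn.Linear(4, 4)   # same seed, so same random init

assert torch.equal(a.weight, b.weight)
assert torch.equal(a.bias, b.bias)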
@@ -294,18 +297,21 @@ if __name__ == "__main__":
     tokenizer = tiktoken.get_encoding("gpt2")
 
+    train_dataset = None
     if args.context_length == "model_context_length":
         max_length = model.pos_emb.weight.shape[0]
     elif args.context_length == "longest_training_example":
-        max_length = None
+        train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+        max_length = train_dataset.max_length
     else:
         try:
             max_length = int(args.context_length)
         except ValueError:
             raise ValueError("Invalid --context_length argument")
 
-    train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
-    val_dataset = IMDBDataset(base_path / "val.csv", max_length=max_length, tokenizer=tokenizer)
+    if train_dataset is None:
+        train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
+    val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
     test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
 
     num_workers = 0
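The commit title, "use training set len", refers to the longest_training_example branch above: the training split is now tokenized first with max_length=None, its longest sequence length is read back from train_dataset.max_length, and the validation and test splits are then padded or truncated to that same length. Below is a minimal sketch of the IMDBDataset behavior this relies on; the class itself lives in the repo, and the column names and pad-token choice here are assumptions:

import pandas as pd
import torch
from torch.utils.data import Dataset

class IMDBDataset(Dataset):
    def __init__(self, csv_path, max_length, tokenizer, pad_token_id=50256):
        df = pd.read_csv(csv_path)
        # "text" and "label" column names are assumptions for this sketch
        self.encoded = [tokenizer.encode(text) for text in df["text"]]
        self.labels = df["label"].tolist()

        if max_length is None:
            # max_length=None means "use the longest example in this split";
            # the result is stored so other splits can reuse it
            self.max_length = max(len(seq) for seq in self.encoded)
        else:
            self.max_length = max_length
            self.encoded = [seq[:max_length] for seq in self.encoded]

        # Pad every example to the shared length
        self.encoded = [seq + [pad_token_id] * (self.max_length - len(seq))
                        for seq in self.encoded]

    def __getitem__(self, index):
        return torch.tensor(self.encoded[index]), torch.tensor(self.labels[index])

    def __len__(self):
        return len(self.labels)

Constructing the training dataset first and guarding the fallback with "if train_dataset is None" avoids tokenizing train.csv twice when the longest_training_example option is chosen.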