mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Use training set length (take max_length from the training dataset)
This commit is contained in:
@@ -68,7 +68,7 @@ def partition_and_save(df, sizes=(35000, 5000, 10000)):
|
||||
|
||||
# Save to CSV files
|
||||
train.to_csv("train.csv", index=False)
|
||||
val.to_csv("val.csv", index=False)
|
||||
val.to_csv("validation.csv", index=False)
|
||||
test.to_csv("test.csv", index=False)
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,9 @@ def instantiate_model(choose_model, load_weights):
|
||||
}
|
||||
|
||||
BASE_CONFIG.update(model_configs[choose_model])
|
||||
|
||||
if not load_weights:
|
||||
torch.manual_seed(123)
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
|
||||
if load_weights:
|
||||
@@ -294,18 +297,21 @@ if __name__ == "__main__":
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_dataset = None
|
||||
if args.context_length == "model_context_length":
|
||||
max_length = model.pos_emb.weight.shape[0]
|
||||
elif args.context_length == "longest_training_example":
|
||||
max_length = None
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
max_length = train_dataset.max_length
|
||||
else:
|
||||
try:
|
||||
max_length = int(args.context_length)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid --context_length argument")
|
||||
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDBDataset(base_path / "val.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
if train_dataset is None:
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
|
||||
num_workers = 0
|
||||
|
||||
Reference in New Issue
Block a user