mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
committed by
GitHub
parent
18c6b970ab
commit
6d175a22df
@@ -1,4 +1,4 @@
|
||||
# Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews
|
||||
# Additional Experiments Classifying the Sentiment of 50k IMDb Movie Reviews
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
|
||||
class IMDBDataset(Dataset):
|
||||
class IMDbDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
|
||||
@@ -375,21 +375,21 @@ if __name__ == "__main__":
|
||||
else:
|
||||
raise ValueError("Invalid argument for `use_attention_mask`.")
|
||||
|
||||
train_dataset = IMDBDataset(
|
||||
train_dataset = IMDbDataset(
|
||||
base_path / "train.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
val_dataset = IMDBDataset(
|
||||
val_dataset = IMDbDataset(
|
||||
base_path / "validation.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
test_dataset = IMDBDataset(
|
||||
test_dataset = IMDbDataset(
|
||||
base_path / "test.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -17,7 +17,7 @@ from gpt_download import download_and_load_gpt2
|
||||
from previous_chapters import GPTModel, load_weights_into_gpt
|
||||
|
||||
|
||||
class IMDBDataset(Dataset):
|
||||
class IMDbDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
|
||||
@@ -368,7 +368,7 @@ if __name__ == "__main__":
|
||||
if args.context_length == "model_context_length":
|
||||
max_length = model.pos_emb.weight.shape[0]
|
||||
elif args.context_length == "longest_training_example":
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
train_dataset = IMDbDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
max_length = train_dataset.max_length
|
||||
else:
|
||||
try:
|
||||
@@ -377,9 +377,9 @@ if __name__ == "__main__":
|
||||
raise ValueError("Invalid --context_length argument")
|
||||
|
||||
if train_dataset is None:
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
train_dataset = IMDbDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDbDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
test_dataset = IMDbDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
|
||||
num_workers = 0
|
||||
batch_size = 8
|
||||
|
||||
Reference in New Issue
Block a user