mirror of https://github.com/rasbt/LLMs-from-scratch.git

Add new experiment without padding
@@ -46,7 +46,7 @@ class LinearWithLoRA(torch.nn.Module):
 
 
 class SpamDataset(Dataset):
-    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
+    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, skip_padding=False):
         self.data = pd.read_csv(csv_file)
         self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
 
@@ -55,11 +55,13 @@ class SpamDataset(Dataset):
             tokenizer.encode(text)[:self.max_length]
             for text in self.data["Text"]
         ]
-        # Pad sequences to the longest sequence
-        self.encoded_texts = [
-            et + [pad_token_id] * (self.max_length - len(et))
-            for et in self.encoded_texts
-        ]
+
+        if not skip_padding:
+            # Pad sequences to the longest sequence
+            self.encoded_texts = [
+                et + [pad_token_id] * (self.max_length - len(et))
+                for et in self.encoded_texts
+            ]
 
     def __getitem__(self, index):
         encoded = self.encoded_texts[index]
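Why this change forces a batch size of 1 further down: once padding is skipped, the entries of self.encoded_texts have different lengths, and PyTorch's default collate function cannot stack unequal-length tensors into one batch. A minimal sketch (not part of the commit) that reproduces the failure and shows the batch_size=1 workaround:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyVariableLengthDataset(Dataset):
    """Stand-in for SpamDataset with skip_padding=True: unequal-length samples."""
    def __init__(self):
        self.samples = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]

try:
    next(iter(DataLoader(ToyVariableLengthDataset(), batch_size=2)))
except RuntimeError as err:
    # Default collate calls torch.stack, which rejects shapes (3,) and (2,)
    print("batch_size=2 fails:", err)

for batch in DataLoader(ToyVariableLengthDataset(), batch_size=1):
    print(batch.shape)  # torch.Size([1, 3]), then torch.Size([1, 2])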
@@ -334,6 +336,23 @@ if __name__ == "__main__":
             "The LoRA alpha value when choosing `--trainable_layers lora`"
         )
     )
+    parser.add_argument(
+        "--no_padding",
+        action="store_true",
+        default=False,
+        help=(
+            "Disable padding. When this flag is set, the model is trained"
+            " with a batch size of 1 and without padding."
+        )
+    )
+    parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=5,
+        help=(
+            "Number of training epochs."
+        )
+    )
 
     args = parser.parse_args()
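For reference, a self-contained sketch of just the two new flags (the real parser defines many more arguments, and the script's file name is not shown in this diff):

import argparse

# Hypothetical stand-alone parser mirroring only the two arguments added above
parser = argparse.ArgumentParser()
parser.add_argument("--no_padding", action="store_true", default=False)
parser.add_argument("--num_epochs", type=int, default=5)

# Equivalent of invoking the script with: --no_padding --num_epochs 3
args = parser.parse_args(["--no_padding", "--num_epochs", "3"])
print(args.no_padding, args.num_epochs)  # prints: True 3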
@@ -411,26 +430,35 @@ if __name__ == "__main__":
     tokenizer = tiktoken.get_encoding("gpt2")
 
-    if args.context_length == "model_context_length":
-        max_length = model.pos_emb.weight.shape[0]
-    elif args.context_length == "longest_training_example":
-        train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
-        max_length = train_dataset.max_length
-    else:
-        try:
-            max_length = int(args.context_length)
-        except ValueError:
-            raise ValueError("Invalid --context_length argument")
+    train_dataset = None
+    if args.no_padding:
+        max_length = None
+    else:
+        if args.context_length == "model_context_length":
+            max_length = model.pos_emb.weight.shape[0]
+        elif args.context_length == "longest_training_example":
+            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+            max_length = train_dataset.max_length
+        else:
+            try:
+                max_length = int(args.context_length)
+            except ValueError:
+                raise ValueError("Invalid --context_length argument")
 
-    train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
+    if train_dataset is None:
+        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
 
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     num_workers = 0
-    batch_size = 8
+
+    if args.no_padding:
+        batch_size = 1
+    else:
+        batch_size = 8
 
     train_loader = DataLoader(
         dataset=train_dataset,
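Two details in this hunk are worth spelling out. Initializing train_dataset = None and checking if train_dataset is None later avoids tokenizing train.csv twice when --context_length longest_training_example has already built the training dataset in order to measure it. And the branching that picks max_length can be summarized by the following stand-alone sketch (a hypothetical helper, not part of the commit; the default values are illustrative):

def resolve_max_length(context_length, no_padding,
                       model_context_length=1024, longest_example=120):
    """Mirror the max_length selection above; None means 'no padding'."""
    if no_padding:
        return None  # SpamDataset then falls back to its longest encoded length
    if context_length == "model_context_length":
        return model_context_length  # e.g. model.pos_emb.weight.shape[0]
    if context_length == "longest_training_example":
        return longest_example  # measured from the training set
    try:
        return int(context_length)
    except ValueError:
        raise ValueError("Invalid --context_length argument")

assert resolve_max_length("256", no_padding=False) == 256
assert resolve_max_length("model_context_length", no_padding=False) == 1024
assert resolve_max_length("anything", no_padding=True) is None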
@@ -462,10 +490,9 @@ if __name__ == "__main__":
     torch.manual_seed(123)
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
 
-    num_epochs = 5
     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
         model, train_loader, val_loader, optimizer, device,
-        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
+        num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
         tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
     )
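Taken together, the commit makes the no-padding experiment runnable end to end: a plausible invocation would be along the lines of python additional_experiments.py --no_padding --num_epochs 5 (the script's file name is assumed here, as it is not shown in the diff). With --no_padding set, max_length resolves to None, every dataset is built with skip_padding=True, and training proceeds one variable-length sequence per batch.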