From 371ab9e8ff8801d9af4b102cf0b4282564673ba6 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Sat, 5 Apr 2025 10:05:15 -0500
Subject: [PATCH] Correct BERT experiments (#600)

---
 .../train_bert_hf.py | 439 +++++++++---------
 1 file changed, 222 insertions(+), 217 deletions(-)

diff --git a/ch06/03_bonus_imdb-classification/train_bert_hf.py b/ch06/03_bonus_imdb-classification/train_bert_hf.py
index 99091d3..11b0c5e 100644
--- a/ch06/03_bonus_imdb-classification/train_bert_hf.py
+++ b/ch06/03_bonus_imdb-classification/train_bert_hf.py
@@ -8,35 +8,55 @@ from pathlib import Path
 import time

 import pandas as pd
-import tiktoken
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset

-from gpt_download import download_and_load_gpt2
-from previous_chapters import GPTModel, load_weights_into_gpt
+from transformers import AutoTokenizer, AutoModelForSequenceClassification


 class IMDBDataset(Dataset):
-    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
+    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False):
         self.data = pd.read_csv(csv_file)
         self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
+        self.pad_token_id = pad_token_id
+        self.use_attention_mask = use_attention_mask

-        # Pre-tokenize texts
+        # Pre-tokenize texts and create attention masks if required
         self.encoded_texts = [
-            tokenizer.encode(text)[:self.max_length]
+            tokenizer.encode(text, truncation=True, max_length=self.max_length)
             for text in self.data["text"]
         ]
-        # Pad sequences to the longest sequence
         self.encoded_texts = [
             et + [pad_token_id] * (self.max_length - len(et))
             for et in self.encoded_texts
         ]
+        if self.use_attention_mask:
+            self.attention_masks = [
+                self._create_attention_mask(et)
+                for et in self.encoded_texts
+            ]
+        else:
+            self.attention_masks = None
+
+    def _create_attention_mask(self, encoded_text):
+        return [1 if token_id != self.pad_token_id else 0 for token_id in encoded_text]
+
     def __getitem__(self, index):
         encoded = self.encoded_texts[index]
         label = self.data.iloc[index]["label"]
-        return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
+
+        if self.use_attention_mask:
+            attention_mask = self.attention_masks[index]
+        else:
+            attention_mask = torch.ones(self.max_length, dtype=torch.long)
+
+        return (
+            torch.tensor(encoded, dtype=torch.long),
+            torch.tensor(attention_mask, dtype=torch.long),
+            torch.tensor(label, dtype=torch.long)
+        )

     def __len__(self):
         return len(self.data)
@@ -50,71 +70,27 @@ class IMDBDataset(Dataset):
         return max_length


-def instantiate_model(choose_model, load_weights):
-
-    BASE_CONFIG = {
-        "vocab_size": 50257,     # Vocabulary size
-        "context_length": 1024,  # Context length
-        "drop_rate": 0.0,        # Dropout rate
-        "qkv_bias": True         # Query-key-value bias
-    }
-
-    model_configs = {
-        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
-        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
-        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
-        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
-    }
-
-    BASE_CONFIG.update(model_configs[choose_model])
-
-    if not load_weights:
-        torch.manual_seed(123)
-    model = GPTModel(BASE_CONFIG)
-
-    if load_weights:
-        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
-        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
-        load_weights_into_gpt(model, params)
-
-    model.eval()
-    return model
-
-
-def calc_loss_batch(input_batch, target_batch, model, device,
-                    trainable_token_pos=-1, average_embeddings=False):
+def calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device):
+    attention_mask_batch = attention_mask_batch.to(device)
     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-
-    model_output = model(input_batch)
-    if average_embeddings:
-        # Average over the sequence dimension (dim=1)
-        logits = model_output.mean(dim=1)
-    else:
-        # Select embeddings at the specified token position
-        logits = model_output[:, trainable_token_pos, :]
-
+    # logits = model(input_batch)[:, -1, :]  # Logits of last output token
+    logits = model(input_batch, attention_mask=attention_mask_batch).logits
     loss = torch.nn.functional.cross_entropy(logits, target_batch)
     return loss


-def calc_loss_loader(data_loader, model, device,
-                     num_batches=None, trainable_token_pos=-1,
-                     average_embeddings=False):
+# Same as in chapter 5
+def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
-    if len(data_loader) == 0:
-        return float("nan")
-    elif num_batches is None:
+    if num_batches is None:
         num_batches = len(data_loader)
     else:
         # Reduce the number of batches to match the total number of batches in the data loader
         # if num_batches exceeds the number of batches in the data loader
         num_batches = min(num_batches, len(data_loader))
-    for i, (input_batch, target_batch) in enumerate(data_loader):
+    for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
-            loss = calc_loss_batch(
-                input_batch, target_batch, model, device,
-                trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-            )
+            loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
             total_loss += loss.item()
         else:
             break
@@ -122,9 +98,7 @@ def calc_loss_loader(data_loader, model, device,


 @torch.no_grad()  # Disable gradient tracking for efficiency
-def calc_accuracy_loader(data_loader, model, device,
-                         num_batches=None, trainable_token_pos=-1,
-                         average_embeddings=False):
+def calc_accuracy_loader(data_loader, model, device, num_batches=None):
     model.eval()
     correct_predictions, num_examples = 0, 0

@@ -132,20 +106,13 @@ def calc_accuracy_loader(data_loader, model, device,
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
-    for i, (input_batch, target_batch) in enumerate(data_loader):
+    for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
+            attention_mask_batch = attention_mask_batch.to(device)
             input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-
-            model_output = model(input_batch)
-            if average_embeddings:
-                # Average over the sequence dimension (dim=1)
-                logits = model_output.mean(dim=1)
-            else:
-                # Select embeddings at the specified token position
-                logits = model_output[:, trainable_token_pos, :]
-
-            predicted_labels = torch.argmax(logits, dim=-1)
-
+            # logits = model(input_batch)[:, -1, :]  # Logits of last output token
+            logits = model(input_batch, attention_mask=attention_mask_batch).logits
+            predicted_labels = torch.argmax(logits, dim=1)
             num_examples += predicted_labels.shape[0]
             correct_predictions += (predicted_labels == target_batch).sum().item()
         else:
@@ -153,25 +120,17 @@ def calc_accuracy_loader(data_loader, model, device,
     return correct_predictions / num_examples


-def evaluate_model(model, train_loader, val_loader, device, eval_iter,
-                   trainable_token_pos=-1, average_embeddings=False):
+def evaluate_model(model, train_loader, val_loader, device, eval_iter):
     model.eval()
     with torch.no_grad():
-        train_loss = calc_loss_loader(
-            train_loader, model, device, num_batches=eval_iter,
-            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-        )
-        val_loss = calc_loss_loader(
-            val_loader, model, device, num_batches=eval_iter,
-            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-        )
+        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
+        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
     model.train()
     return train_loss, val_loss


 def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-                            eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
-                            average_embeddings=False):
+                            eval_freq, eval_iter, max_steps=None):
     # Initialize lists to track losses and tokens seen
     train_losses, val_losses, train_accs, val_accs = [], [], [], []
     examples_seen, global_step = 0, -1
@@ -180,10 +139,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
     for epoch in range(num_epochs):
         model.train()  # Set model to training mode

-        for input_batch, target_batch in train_loader:
+        for input_batch, attention_mask_batch, target_batch in train_loader:
             optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
-            loss = calc_loss_batch(input_batch, target_batch, model, device,
-                                   trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
+            loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
             loss.backward()  # Calculate loss gradients
             optimizer.step()  # Update model weights using loss gradients
             examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
@@ -192,9 +150,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
             # Optional evaluation step
             if global_step % eval_freq == 0:
                 train_loss, val_loss = evaluate_model(
-                    model, train_loader, val_loader, device, eval_iter,
-                    trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-                )
+                    model, train_loader, val_loader, device, eval_iter)
                 train_losses.append(train_loss)
                 val_losses.append(val_loss)
                 print(f"Ep {epoch+1} (Step {global_step:06d}): "
@@ -204,14 +160,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
                 break

         # New: Calculate accuracy after each epoch
-        train_accuracy = calc_accuracy_loader(
-            train_loader, model, device, num_batches=eval_iter,
-            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-        )
-        val_accuracy = calc_accuracy_loader(
-            val_loader, model, device, num_batches=eval_iter,
-            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
-        )
+        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
+        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
         print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
         print(f"Validation accuracy: {val_accuracy*100:.2f}%")
         train_accs.append(train_accuracy)
@@ -226,55 +176,28 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,

 if __name__ == "__main__":

     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model_size",
-        type=str,
-        default="gpt2-small (124M)",
-        help=(
-            "Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
-            " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
-        )
-    )
-    parser.add_argument(
-        "--weights",
-        type=str,
-        default="pretrained",
-        help=(
-            "Whether to use 'pretrained' or 'random' weights."
-        )
-    )
     parser.add_argument(
         "--trainable_layers",
         type=str,
-        default="last_block",
+        default="all",
         help=(
             "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
         )
     )
     parser.add_argument(
-        "--trainable_token_pos",
+        "--use_attention_mask",
         type=str,
-        default="last",
+        default="true",
         help=(
-            "Which token to train. Options: 'first', 'last'."
+            "Whether to use an attention mask for padding tokens. Options: 'true', 'false'."
         )
     )
     parser.add_argument(
-        "--average_embeddings",
-        action='store_true',
-        default=False,
-        help=(
-            "Average the output embeddings from all tokens instead of using"
-            " only the embedding at the token position specified by `--trainable_token_pos`."
-        )
-    )
-    parser.add_argument(
-        "--context_length",
+        "--model",
         type=str,
-        default="256",
+        default="distilbert",
         help=(
-            "The context length of the data inputs."
-            "Options: 'longest_training_example', 'model_context_length' or integer value."
+            "Which model to train. Options: 'distilbert', 'bert', 'roberta', 'modern-bert-base', 'modern-bert-large'."
         )
     )
     parser.add_argument(
@@ -288,73 +211,155 @@ if __name__ == "__main__":
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-5,
+        default=5e-6,
         help=(
             "Learning rate."
         )
     )
-    parser.add_argument(
-        "--compile",
-        action="store_true",
-        help="If set, model compilation will be enabled."
-    )
     args = parser.parse_args()

-    if args.trainable_token_pos == "first":
-        args.trainable_token_pos = 0
-    elif args.trainable_token_pos == "last":
-        args.trainable_token_pos = -1
-    else:
-        raise ValueError("Invalid --trainable_token_pos argument")
-
     ###############################
     # Load model
     ###############################

-    if args.weights == "pretrained":
-        load_weights = True
-    elif args.weights == "random":
-        load_weights = False
-    else:
-        raise ValueError("Invalid --weights argument.")
-
-    model = instantiate_model(args.model_size, load_weights)
-    for param in model.parameters():
-        param.requires_grad = False
-
-    if args.model_size == "gpt2-small (124M)":
-        in_features = 768
-    elif args.model_size == "gpt2-medium (355M)":
-        in_features = 1024
-    elif args.model_size == "gpt2-large (774M)":
-        in_features = 1280
-    elif args.model_size == "gpt2-xl (1558M)":
-        in_features = 1600
-    else:
-        raise ValueError("Invalid --model_size argument")
-
     torch.manual_seed(123)
-    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
+    if args.model == "distilbert":

-    if args.trainable_layers == "last_layer":
-        pass
-    elif args.trainable_layers == "last_block":
-        for param in model.trf_blocks[-1].parameters():
-            param.requires_grad = True
-        for param in model.final_norm.parameters():
-            param.requires_grad = True
-    elif args.trainable_layers == "all":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "distilbert-base-uncased", num_labels=2
+        )
+        model.out_head = torch.nn.Linear(in_features=768, out_features=2)
         for param in model.parameters():
-            param.requires_grad = True
+            param.requires_grad = False
+        if args.trainable_layers == "last_layer":
+            for param in model.out_head.parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "last_block":
+            for param in model.pre_classifier.parameters():
+                param.requires_grad = True
+            for param in model.distilbert.transformer.layer[-1].parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+    elif args.model == "bert":
+
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "bert-base-uncased", num_labels=2
+        )
+        model.classifier = torch.nn.Linear(in_features=768, out_features=2)
+        for param in model.parameters():
+            param.requires_grad = False
+        if args.trainable_layers == "last_layer":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "last_block":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+            for param in model.bert.pooler.dense.parameters():
+                param.requires_grad = True
+            for param in model.bert.encoder.layer[-1].parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    elif args.model == "roberta":
+
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "FacebookAI/roberta-large", num_labels=2
+        )
+        model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
+        for param in model.parameters():
+            param.requires_grad = False
+        if args.trainable_layers == "last_layer":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "last_block":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+            for param in model.roberta.encoder.layer[-1].parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
+
+    elif args.model in ("modern-bert-base", "modern-bert-large"):
+
+        if args.model == "modern-bert-base":
+            model = AutoModelForSequenceClassification.from_pretrained(
+                "answerdotai/ModernBERT-base", num_labels=2
+            )
+            model.classifier = torch.nn.Linear(in_features=768, out_features=2)
+        else:
+            model = AutoModelForSequenceClassification.from_pretrained(
+                "answerdotai/ModernBERT-large", num_labels=2
+            )
+            model.classifier = torch.nn.Linear(in_features=1024, out_features=2)
+        for param in model.parameters():
+            param.requires_grad = False
+        if args.trainable_layers == "last_layer":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "last_block":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+            for param in model.model.layers[-1].parameters():
+                param.requires_grad = True
+            for param in model.head.parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
     else:
-        raise ValueError("Invalid --trainable_layers argument.")
+        raise ValueError(f"Selected --model {args.model} not supported.")

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-
-    if args.compile:
-        torch.set_float32_matmul_precision("high")
-        model = torch.compile(model)
+    model.eval()

     ###############################
     # Instantiate dataloaders
@@ -362,24 +367,34 @@ if __name__ == "__main__":

     base_path = Path(".")

-    tokenizer = tiktoken.get_encoding("gpt2")
-
-    train_dataset = None
-    if args.context_length == "model_context_length":
-        max_length = model.pos_emb.weight.shape[0]
-    elif args.context_length == "longest_training_example":
-        train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
-        max_length = train_dataset.max_length
+    if args.use_attention_mask.lower() == "true":
+        use_attention_mask = True
+    elif args.use_attention_mask.lower() == "false":
+        use_attention_mask = False
     else:
-        try:
-            max_length = int(args.context_length)
-        except ValueError:
-            raise ValueError("Invalid --context_length argument")
+        raise ValueError("Invalid argument for `use_attention_mask`.")

-    if train_dataset is None:
-        train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
-    val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
-    test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
+    train_dataset = IMDBDataset(
+        base_path / "train.csv",
+        max_length=256,
+        tokenizer=tokenizer,
+        pad_token_id=tokenizer.pad_token_id,
+        use_attention_mask=use_attention_mask
+    )
+    val_dataset = IMDBDataset(
+        base_path / "validation.csv",
+        max_length=256,
+        tokenizer=tokenizer,
+        pad_token_id=tokenizer.pad_token_id,
+        use_attention_mask=use_attention_mask
+    )
+    test_dataset = IMDBDataset(
+        base_path / "test.csv",
+        max_length=256,
+        tokenizer=tokenizer,
+        pad_token_id=tokenizer.pad_token_id,
+        use_attention_mask=use_attention_mask
+    )

     num_workers = 0
     batch_size = 8
@@ -417,8 +432,7 @@ if __name__ == "__main__":
     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
         model, train_loader, val_loader, optimizer, device,
         num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
-        max_steps=None, trainable_token_pos=args.trainable_token_pos,
-        average_embeddings=args.average_embeddings
+        max_steps=None
     )

     end_time = time.time()
@@ -431,19 +445,10 @@ if __name__ == "__main__":

     print("\nEvaluating on the full datasets ...\n")

-    train_accuracy = calc_accuracy_loader(
-        train_loader, model, device,
-        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
-    )
-    val_accuracy = calc_accuracy_loader(
-        val_loader, model, device,
-        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
-    )
-    test_accuracy = calc_accuracy_loader(
-        test_loader, model, device,
-        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
-    )
+    train_accuracy = calc_accuracy_loader(train_loader, model, device)
+    val_accuracy = calc_accuracy_loader(val_loader, model, device)
+    test_accuracy = calc_accuracy_loader(test_loader, model, device)

     print(f"Training accuracy: {train_accuracy*100:.2f}%")
     print(f"Validation accuracy: {val_accuracy*100:.2f}%")
-    print(f"Test accuracy: {test_accuracy*100:.2f}%")
+    print(f"Test accuracy: {test_accuracy*100:.2f}%")
\ No newline at end of file
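
The snippet below is a minimal, illustrative usage sketch (not part of the commit) of how the patched IMDBDataset is expected to behave once the patch is applied. It assumes it is run from ch06/03_bonus_imdb-classification with a train.csv containing "text" and "label" columns and with torch, pandas, and transformers installed; the class, file, and column names come from the script above, while the snippet itself and its settings are only an example.

    # usage_sketch.py -- illustrative only, not part of the patch
    from transformers import AutoTokenizer
    from train_bert_hf import IMDBDataset  # safe to import: training runs only under __main__

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    dataset = IMDBDataset(
        "train.csv",
        tokenizer=tokenizer,
        max_length=256,
        pad_token_id=tokenizer.pad_token_id,  # 0 for BERT-style tokenizers
        use_attention_mask=True,
    )
    input_ids, attention_mask, label = dataset[0]
    # input_ids and attention_mask are length-256 long tensors; padded positions
    # hold pad_token_id in input_ids and 0 in attention_mask, and label is a scalar.
    print(input_ids.shape, attention_mask.shape, label)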