mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Correct BERT experiments (#600)
This commit is contained in:
committed by
GitHub
parent
14f976e024
commit
ab17357474
@@ -8,35 +8,55 @@ from pathlib import Path
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from gpt_download import download_and_load_gpt2
|
||||
from previous_chapters import GPTModel, load_weights_into_gpt
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
|
||||
class IMDBDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
|
||||
self.pad_token_id = pad_token_id
|
||||
self.use_attention_mask = use_attention_mask
|
||||
|
||||
# Pre-tokenize texts
|
||||
# Pre-tokenize texts and create attention masks if required
|
||||
self.encoded_texts = [
|
||||
tokenizer.encode(text)[:self.max_length]
|
||||
tokenizer.encode(text, truncation=True, max_length=self.max_length)
|
||||
for text in self.data["text"]
|
||||
]
|
||||
# Pad sequences to the longest sequence
|
||||
self.encoded_texts = [
|
||||
et + [pad_token_id] * (self.max_length - len(et))
|
||||
for et in self.encoded_texts
|
||||
]
|
||||
|
||||
if self.use_attention_mask:
|
||||
self.attention_masks = [
|
||||
self._create_attention_mask(et)
|
||||
for et in self.encoded_texts
|
||||
]
|
||||
else:
|
||||
self.attention_masks = None
|
||||
|
||||
def _create_attention_mask(self, encoded_text):
|
||||
return [1 if token_id != self.pad_token_id else 0 for token_id in encoded_text]
|
||||
|
||||
def __getitem__(self, index):
|
||||
encoded = self.encoded_texts[index]
|
||||
label = self.data.iloc[index]["label"]
|
||||
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
|
||||
|
||||
if self.use_attention_mask:
|
||||
attention_mask = self.attention_masks[index]
|
||||
else:
|
||||
attention_mask = torch.ones(self.max_length, dtype=torch.long)
|
||||
|
||||
return (
|
||||
torch.tensor(encoded, dtype=torch.long),
|
||||
torch.tensor(attention_mask, dtype=torch.long),
|
||||
torch.tensor(label, dtype=torch.long)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@@ -50,71 +70,27 @@ class IMDBDataset(Dataset):
|
||||
return max_length
|
||||
|
||||
|
||||
def instantiate_model(choose_model, load_weights):
|
||||
|
||||
BASE_CONFIG = {
|
||||
"vocab_size": 50257, # Vocabulary size
|
||||
"context_length": 1024, # Context length
|
||||
"drop_rate": 0.0, # Dropout rate
|
||||
"qkv_bias": True # Query-key-value bias
|
||||
}
|
||||
|
||||
model_configs = {
|
||||
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
||||
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
||||
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
||||
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
||||
}
|
||||
|
||||
BASE_CONFIG.update(model_configs[choose_model])
|
||||
|
||||
if not load_weights:
|
||||
torch.manual_seed(123)
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
|
||||
if load_weights:
|
||||
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
load_weights_into_gpt(model, params)
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
def calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device):
|
||||
attention_mask_batch = attention_mask_batch.to(device)
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
# logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
logits = model(input_batch, attention_mask=attention_mask_batch).logits
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
||||
return loss
|
||||
|
||||
|
||||
def calc_loss_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
# Same as in chapter 5
|
||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||
total_loss = 0.
|
||||
if len(data_loader) == 0:
|
||||
return float("nan")
|
||||
elif num_batches is None:
|
||||
if num_batches is None:
|
||||
num_batches = len(data_loader)
|
||||
else:
|
||||
# Reduce the number of batches to match the total number of batches in the data loader
|
||||
# if num_batches exceeds the number of batches in the data loader
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
break
|
||||
@@ -122,9 +98,7 @@ def calc_loss_loader(data_loader, model, device,
|
||||
|
||||
|
||||
@torch.no_grad() # Disable gradient tracking for efficiency
|
||||
def calc_accuracy_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
model.eval()
|
||||
correct_predictions, num_examples = 0, 0
|
||||
|
||||
@@ -132,20 +106,13 @@ def calc_accuracy_loader(data_loader, model, device,
|
||||
num_batches = len(data_loader)
|
||||
else:
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
attention_mask_batch = attention_mask_batch.to(device)
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
# logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
logits = model(input_batch, attention_mask=attention_mask_batch).logits
|
||||
predicted_labels = torch.argmax(logits, dim=1)
|
||||
num_examples += predicted_labels.shape[0]
|
||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||
else:
|
||||
@@ -153,25 +120,17 @@ def calc_accuracy_loader(data_loader, model, device,
|
||||
return correct_predictions / num_examples
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_loss = calc_loss_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
|
||||
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
|
||||
model.train()
|
||||
return train_loss, val_loss
|
||||
|
||||
|
||||
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
eval_freq, eval_iter, max_steps=None):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, train_accs, val_accs = [], [], [], []
|
||||
examples_seen, global_step = 0, -1
|
||||
@@ -180,10 +139,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
for epoch in range(num_epochs):
|
||||
model.train() # Set model to training mode
|
||||
|
||||
for input_batch, target_batch in train_loader:
|
||||
for input_batch, attention_mask_batch, target_batch in train_loader:
|
||||
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
|
||||
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
|
||||
loss.backward() # Calculate loss gradients
|
||||
optimizer.step() # Update model weights using loss gradients
|
||||
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
|
||||
@@ -192,9 +150,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
# Optional evaluation step
|
||||
if global_step % eval_freq == 0:
|
||||
train_loss, val_loss = evaluate_model(
|
||||
model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
model, train_loader, val_loader, device, eval_iter)
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
print(f"Ep {epoch+1} (Step {global_step:06d}): "
|
||||
@@ -204,14 +160,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
break
|
||||
|
||||
# New: Calculate accuracy after each epoch
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
train_accs.append(train_accuracy)
|
||||
@@ -226,55 +176,28 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
type=str,
|
||||
default="gpt2-small (124M)",
|
||||
help=(
|
||||
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
|
||||
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weights",
|
||||
type=str,
|
||||
default="pretrained",
|
||||
help=(
|
||||
"Whether to use 'pretrained' or 'random' weights."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_layers",
|
||||
type=str,
|
||||
default="last_block",
|
||||
default="all",
|
||||
help=(
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_token_pos",
|
||||
"--use_attention_mask",
|
||||
type=str,
|
||||
default="last",
|
||||
default="true",
|
||||
help=(
|
||||
"Which token to train. Options: 'first', 'last'."
|
||||
"Whether to use a attention mask for padding tokens. Options: 'true', 'false'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--average_embeddings",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=(
|
||||
"Average the output embeddings from all tokens instead of using"
|
||||
" only the embedding at the token position specified by `--trainable_token_pos`."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_length",
|
||||
"--model",
|
||||
type=str,
|
||||
default="256",
|
||||
default="distilbert",
|
||||
help=(
|
||||
"The context length of the data inputs."
|
||||
"Options: 'longest_training_example', 'model_context_length' or integer value."
|
||||
"Which model to train. Options: 'distilbert', 'bert', 'roberta', 'modern-bert-base', 'modern-bert-large."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -288,73 +211,155 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
default=5e-6,
|
||||
help=(
|
||||
"Learning rate."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="If set, model compilation will be enabled."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token_pos == "first":
|
||||
args.trainable_token_pos = 0
|
||||
elif args.trainable_token_pos == "last":
|
||||
args.trainable_token_pos = -1
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_token_pos argument")
|
||||
|
||||
###############################
|
||||
# Load model
|
||||
###############################
|
||||
|
||||
if args.weights == "pretrained":
|
||||
load_weights = True
|
||||
elif args.weights == "random":
|
||||
load_weights = False
|
||||
else:
|
||||
raise ValueError("Invalid --weights argument.")
|
||||
|
||||
model = instantiate_model(args.model_size, load_weights)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if args.model_size == "gpt2-small (124M)":
|
||||
in_features = 768
|
||||
elif args.model_size == "gpt2-medium (355M)":
|
||||
in_features = 1024
|
||||
elif args.model_size == "gpt2-large (774M)":
|
||||
in_features = 1280
|
||||
elif args.model_size == "gpt2-xl (1558M)":
|
||||
in_features = 1600
|
||||
else:
|
||||
raise ValueError("Invalid --model_size argument")
|
||||
|
||||
torch.manual_seed(123)
|
||||
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
|
||||
if args.model == "distilbert":
|
||||
|
||||
if args.trainable_layers == "last_layer":
|
||||
pass
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.trf_blocks[-1].parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.final_norm.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"distilbert-base-uncased", num_labels=2
|
||||
)
|
||||
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.out_head.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.pre_classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.distilbert.transformer.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
elif args.model == "bert":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased", num_labels=2
|
||||
)
|
||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.bert.pooler.dense.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.bert.encoder.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
elif args.model == "roberta":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"FacebookAI/roberta-large", num_labels=2
|
||||
)
|
||||
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.roberta.encoder.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
|
||||
|
||||
elif args.model in ("modern-bert-base", "modern-bert-large"):
|
||||
|
||||
if args.model == "modern-bert-base":
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"answerdotai/ModernBERT-base", num_labels=2
|
||||
)
|
||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||
else:
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"answerdotai/ModernBERT-large", num_labels=2
|
||||
)
|
||||
model.classifier = torch.nn.Linear(in_features=1024, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.model.layers[-1].parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.head.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
|
||||
|
||||
elif args.model == "modern-bert-base":
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"answerdotai/ModernBERT-base", num_labels=2
|
||||
)
|
||||
print(model)
|
||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.layers.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
raise ValueError("Selected --model {args.model} not supported.")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
|
||||
if args.compile:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
model = torch.compile(model)
|
||||
model.eval()
|
||||
|
||||
###############################
|
||||
# Instantiate dataloaders
|
||||
@@ -362,24 +367,34 @@ if __name__ == "__main__":
|
||||
|
||||
base_path = Path(".")
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_dataset = None
|
||||
if args.context_length == "model_context_length":
|
||||
max_length = model.pos_emb.weight.shape[0]
|
||||
elif args.context_length == "longest_training_example":
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
max_length = train_dataset.max_length
|
||||
if args.use_attention_mask.lower() == "true":
|
||||
use_attention_mask = True
|
||||
elif args.use_attention_mask.lower() == "false":
|
||||
use_attention_mask = False
|
||||
else:
|
||||
try:
|
||||
max_length = int(args.context_length)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid --context_length argument")
|
||||
raise ValueError("Invalid argument for `use_attention_mask`.")
|
||||
|
||||
if train_dataset is None:
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
train_dataset = IMDBDataset(
|
||||
base_path / "train.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
val_dataset = IMDBDataset(
|
||||
base_path / "validation.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
test_dataset = IMDBDataset(
|
||||
base_path / "test.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
|
||||
num_workers = 0
|
||||
batch_size = 8
|
||||
@@ -417,8 +432,7 @@ if __name__ == "__main__":
|
||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
|
||||
max_steps=None, trainable_token_pos=args.trainable_token_pos,
|
||||
average_embeddings=args.average_embeddings
|
||||
max_steps=None
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@@ -431,19 +445,10 @@ if __name__ == "__main__":
|
||||
|
||||
print("\nEvaluating on the full datasets ...\n")
|
||||
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
test_accuracy = calc_accuracy_loader(
|
||||
test_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device)
|
||||
test_accuracy = calc_accuracy_loader(test_loader, model, device)
|
||||
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}%")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
print(f"Test accuracy: {test_accuracy*100:.2f}%")
|
||||
print(f"Test accuracy: {test_accuracy*100:.2f}%")
|
||||
Reference in New Issue
Block a user