import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer

# Load the MRPC paraphrase pairs from the GLUE benchmark
raw_datasets = load_dataset("glue", "mrpc")
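
# Illustrative peek (not in the original script): MRPC ships train/validation/
# test splits, each with sentence1, sentence2, label, and idx columns.
print(raw_datasets)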

checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    # Tokenize the sentence pair; leave padding to the data collator
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
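
# Illustrative (not in the original script): on a single example this returns
# variable-length input_ids, token_type_ids, and attention_mask, with no padding.
print(tokenize_function(raw_datasets["train"][0]))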

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# Dynamic padding: each batch is padded to its own longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
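
# Illustrative check (not in the original script): take a few tokenized rows,
# drop the raw-string columns the collator cannot tensorize, and confirm the
# batch is padded only to the longest sequence in that batch.
samples = tokenized_datasets["train"][:8]
samples = {k: v for k, v in samples.items() if k not in ["idx", "sentence1", "sentence2"]}
batch = data_collator(samples)
print({k: v.shape for k, v in batch.items()})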

from transformers import TrainingArguments

# Evaluate at the end of every epoch (newer transformers releases rename this
# argument to eval_strategy)
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")

from transformers import AutoModelForSequenceClassification

# num_labels=2: MRPC is binary (paraphrase vs. not paraphrase)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)


def compute_metrics(eval_preds):
    # Trainer hands us (logits, labels); report the official GLUE/MRPC
    # metrics (accuracy and F1) via the evaluate library
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
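
# Illustrative sanity check (not in the original script): hand-made logits and
# labels in the same shapes Trainer passes during evaluation.
print(compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0]))))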

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
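
# Not in the original file: actually launch fine-tuning, then run a final
# evaluation pass (returns eval_loss plus the accuracy/F1 from compute_metrics).
trainer.train()
print(trainer.evaluate())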