mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
update huggingface code
This commit is contained in:
89
ML/Pytorch/huggingface/dataset.py
Normal file
89
ML/Pytorch/huggingface/dataset.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pandas as pd
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
|
||||
class cnn_dailymail(Dataset):
    """Summarization dataset pairing CNN/DailyMail articles with highlights.

    Each item is a dict of fixed-length tensors ready for a seq2seq model:
    ``input_ids`` / ``attention_mask`` for the article, and ``labels`` holding
    the token ids of the reference summary (the ``highlights`` column).
    """

    def __init__(self, csv_file, tokenizer, max_length=512):
        # CSV is expected to contain "article" and "highlights" columns.
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        # Tokenize to a fixed length (pad + truncate) so items batch directly.
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        row = self.data.loc[idx]
        source = self._encode(row["article"])
        target = self._encode(row["highlights"])
        # squeeze() removes the batch dim added by return_tensors="pt".
        return {
            "input_ids": source["input_ids"].squeeze(),
            "attention_mask": source["attention_mask"].squeeze(),
            "labels": target["input_ids"].squeeze(),
        }
|
||||
|
||||
|
||||
class MyDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the CNN/DailyMail CSV splits.

    Builds one ``cnn_dailymail`` dataset per split in :meth:`setup` and
    exposes the matching train/val/test dataloaders.
    """

    def __init__(
        self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512
    ):
        super().__init__()
        self.train_csv = train_csv
        self.val_csv = val_csv
        self.test_csv = test_csv
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def _build_dataset(self, csv_path):
        # All splits share the same tokenizer and truncation length.
        return cnn_dailymail(csv_path, self.tokenizer, self.max_length)

    def setup(self, stage=None):
        # stage is None when the Trainer wants every split prepared at once.
        if stage == "fit" or stage is None:
            self.train_dataset = self._build_dataset(self.train_csv)
            self.val_dataset = self._build_dataset(self.val_csv)
        if stage == "test" or stage is None:
            self.test_dataset = self._build_dataset(self.test_csv)

    def train_dataloader(self):
        # Only training shuffles; pin_memory speeds up host->GPU transfers.
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            shuffle=True,
            num_workers=6,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1
        )
|
||||
Reference in New Issue
Block a user