import pandas as pd import numpy as np import config import matplotlib.pyplot as plt from torch.utils.data import DataLoader, Dataset class FacialKeypointDataset(Dataset): def __init__(self, csv_file, train=True, transform=None): super().__init__() self.data = pd.read_csv(csv_file) self.category_names = ['left_eye_center_x', 'left_eye_center_y', 'right_eye_center_x', 'right_eye_center_y', 'left_eye_inner_corner_x', 'left_eye_inner_corner_y', 'left_eye_outer_corner_x', 'left_eye_outer_corner_y', 'right_eye_inner_corner_x', 'right_eye_inner_corner_y', 'right_eye_outer_corner_x', 'right_eye_outer_corner_y', 'left_eyebrow_inner_end_x', 'left_eyebrow_inner_end_y', 'left_eyebrow_outer_end_x', 'left_eyebrow_outer_end_y', 'right_eyebrow_inner_end_x', 'right_eyebrow_inner_end_y', 'right_eyebrow_outer_end_x', 'right_eyebrow_outer_end_y', 'nose_tip_x', 'nose_tip_y', 'mouth_left_corner_x', 'mouth_left_corner_y', 'mouth_right_corner_x', 'mouth_right_corner_y', 'mouth_center_top_lip_x', 'mouth_center_top_lip_y', 'mouth_center_bottom_lip_x', 'mouth_center_bottom_lip_y'] self.transform = transform self.train = train def __len__(self): return self.data.shape[0] def __getitem__(self, index): if self.train: image = np.array(self.data.iloc[index, 30].split()).astype(np.float32) labels = np.array(self.data.iloc[index, :30].tolist()) labels[np.isnan(labels)] = -1 else: image = np.array(self.data.iloc[index, 1].split()).astype(np.float32) labels = np.zeros(30) ignore_indices = labels == -1 labels = labels.reshape(15, 2) if self.transform: image = np.repeat(image.reshape(96, 96, 1), 3, 2).astype(np.uint8) augmentations = self.transform(image=image, keypoints=labels) image = augmentations["image"] labels = augmentations["keypoints"] labels = np.array(labels).reshape(-1) labels[ignore_indices] = -1 return image, labels.astype(np.float32) if __name__ == "__main__": ds = FacialKeypointDataset(csv_file="data/train_4.csv", train=True, transform=config.train_transforms) loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0) for idx, (x, y) in enumerate(loader): plt.imshow(x[0][0].detach().cpu().numpy(), cmap='gray') plt.plot(y[0][0::2].detach().cpu().numpy(), y[0][1::2].detach().cpu().numpy(), "go") plt.show()