Files

51 lines
2.5 KiB
Python
Raw Permalink Normal View History

import pandas as pd
import numpy as np
import config
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
class FacialKeypointDataset(Dataset):
    """Face images and their keypoint coordinates, read from a CSV file.

    In training mode the first 30 columns of each row are read as keypoint
    coordinates (15 ``x, y`` pairs) and column 30 as a whitespace-separated
    pixel string.  In non-training mode the pixel string is read from
    column 1 and the labels are all zeros.

    Coordinates that are NaN in the CSV are reported as ``-1``; they stay
    ``-1`` even after a transform has moved the remaining keypoints.

    Args:
        csv_file: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
        train: Whether rows carry keypoint labels (see above).
        transform: Optional callable invoked as
            ``transform(image=..., keypoints=...)`` that returns a mapping
            with ``"image"`` and ``"keypoints"`` entries.
    """

    # Number of label values per sample (15 keypoints * 2 coordinates).
    NUM_COORDS = 30
    # Side length used when reshaping the flat pixel vector for a transform.
    IMAGE_SIDE = 96
    # Sentinel stored in place of a missing (NaN) coordinate.
    MISSING = -1
    # Positional column of the pixel string in training / non-training CSVs.
    TRAIN_IMAGE_COLUMN = 30
    TEST_IMAGE_COLUMN = 1

    def __init__(self, csv_file, train=True, transform=None):
        super().__init__()
        self.data = pd.read_csv(csv_file)
        self.category_names = [
            'left_eye_center_x', 'left_eye_center_y',
            'right_eye_center_x', 'right_eye_center_y',
            'left_eye_inner_corner_x', 'left_eye_inner_corner_y',
            'left_eye_outer_corner_x', 'left_eye_outer_corner_y',
            'right_eye_inner_corner_x', 'right_eye_inner_corner_y',
            'right_eye_outer_corner_x', 'right_eye_outer_corner_y',
            'left_eyebrow_inner_end_x', 'left_eyebrow_inner_end_y',
            'left_eyebrow_outer_end_x', 'left_eyebrow_outer_end_y',
            'right_eyebrow_inner_end_x', 'right_eyebrow_inner_end_y',
            'right_eyebrow_outer_end_x', 'right_eyebrow_outer_end_y',
            'nose_tip_x', 'nose_tip_y',
            'mouth_left_corner_x', 'mouth_left_corner_y',
            'mouth_right_corner_x', 'mouth_right_corner_y',
            'mouth_center_top_lip_x', 'mouth_center_top_lip_y',
            'mouth_center_bottom_lip_x', 'mouth_center_bottom_lip_y',
        ]
        self.transform = transform
        self.train = train

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(image, labels)`` for the row at ``index``.

        ``labels`` is a float32 vector of 30 values with ``-1`` marking
        missing coordinates.  Without a transform ``image`` is the flat
        float32 pixel vector; with one it is whatever the transform returns
        under ``"image"`` for a 96x96x3 uint8 input.
        """
        if self.train:
            image = np.array(
                self.data.iloc[index, self.TRAIN_IMAGE_COLUMN].split()
            ).astype(np.float32)
            # .tolist() first: the row slice of a mixed-type frame may come
            # back with object dtype, which np.isnan cannot handle directly.
            labels = np.array(self.data.iloc[index, :self.NUM_COORDS].tolist())
            labels[np.isnan(labels)] = self.MISSING
        else:
            image = np.array(
                self.data.iloc[index, self.TEST_IMAGE_COLUMN].split()
            ).astype(np.float32)
            labels = np.zeros(self.NUM_COORDS)

        # Remember which coordinates were missing before any augmentation
        # moves them away from the sentinel value.
        ignore_indices = labels == self.MISSING
        labels = labels.reshape(self.NUM_COORDS // 2, 2)

        # Compare against None explicitly: a transform object that defines
        # __len__ and has zero steps is falsy and would otherwise be
        # silently skipped.
        if self.transform is not None:
            side = self.IMAGE_SIDE
            # Replicate the single grey channel three times (H, W, 3).
            image = np.repeat(image.reshape(side, side, 1), 3, 2).astype(np.uint8)
            augmentations = self.transform(image=image, keypoints=labels)
            image = augmentations["image"]
            labels = augmentations["keypoints"]

        # NOTE(review): assumes the transform returns every keypoint it was
        # given; if it drops any, the mask below no longer matches in length
        # -- confirm against the transform configuration.
        labels = np.array(labels).reshape(-1)
        labels[ignore_indices] = self.MISSING

        return image, labels.astype(np.float32)
if __name__ == "__main__":
    # Manual visual check: show each training sample with its keypoints
    # drawn on top, one figure per sample.
    dataset = FacialKeypointDataset(
        csv_file="data/train_4.csv",
        train=True,
        transform=config.train_transforms,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    for images, keypoints in loader:
        # batch_size=1, so element 0 is the only sample; [0][0] then picks
        # its first slice -- presumably the first colour channel, assuming
        # the transform yields channel-first tensors (TODO confirm).
        face = images[0][0].detach().cpu().numpy()
        # Labels are interleaved: even positions and odd positions are
        # plotted as the horizontal and vertical coordinates respectively.
        xs = keypoints[0][0::2].detach().cpu().numpy()
        ys = keypoints[0][1::2].detach().cpu().numpy()

        plt.imshow(face, cmap='gray')
        plt.plot(xs, ys, "go")
        plt.show()