import torch
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_mean_std(loader, *, progress=True):
    """Compute the per-channel mean and standard deviation of a dataset.

    Uses Var[X] = E[X**2] - E[X]**2. Sums are accumulated over every
    element and divided by the total element count per channel. Averaging
    per-batch means instead would over-weight a smaller final batch.

    Args:
        loader: Iterable yielding ``(data, target)`` pairs. ``data`` must be
            a 4-D tensor whose dim 1 is the channel axis; statistics are
            reduced over dims 0, 2 and 3. ``target`` is ignored.
        progress: When True, wrap ``loader`` in a ``tqdm`` progress bar.

    Returns:
        ``(mean, std)``: 1-D tensors with one entry per channel. ``std`` is
        the population (biased) standard deviation. Both use the input's
        floating dtype (the default float dtype for non-float input).

    Raises:
        ValueError: If ``loader`` yields no elements.
    """
    if progress:
        # Imported lazily so the function works without tqdm when progress=False.
        from tqdm import tqdm

        loader = tqdm(loader)

    channels_sum, channels_sqrd_sum, count = 0, 0, 0
    out_dtype = None
    for data, _ in loader:
        out_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
        # float64 accumulation keeps E[X**2] - E[X]**2 from losing precision
        # over tens of millions of pixels.
        batch = data.to(torch.float64)
        channels_sum = channels_sum + batch.sum(dim=[0, 2, 3])
        channels_sqrd_sum = channels_sqrd_sum + (batch ** 2).sum(dim=[0, 2, 3])
        # Number of values contributing to each channel in this batch.
        count += batch.numel() // batch.shape[1]

    if count == 0:
        raise ValueError("loader yielded no data; cannot compute mean/std")

    mean = channels_sum / count
    # Rounding can make the variance a hair below zero; clamp to avoid NaN.
    var = (channels_sqrd_sum / count - mean ** 2).clamp(min=0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)


def main():
    """Download CIFAR10 into ``ds/`` and print its per-channel mean and std."""
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    train_set = datasets.CIFAR10(root="ds/", transform=transforms.ToTensor(), download=True)
    train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
    mean, std = get_mean_std(train_loader)
    print(mean)
    print(std)


if __name__ == "__main__":
    main()