import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

# NOTE: the optimizer and loss_fn defined above are not used anywhere in this
# file; a sketch of the missing training loop is appended at the end of the file.


def inference(digit, num_examples=1):
    """
    Generates num_examples images of a particular digit.
    Specifically, we extract one example of each digit; once we have the
    (mu, sigma) representation of each digit, we can sample from it.
    After sampling, we run the decoder part of the VAE to generate new examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            # Move the input to the model's device to avoid a CPU/GPU mismatch
            mu, sigma = model.encode(images[d].view(1, 784).to(DEVICE))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I)
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")


for idx in range(10):
    inference(idx, num_examples=5)
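
# ---------------------------------------------------------------------------
# Training loop: a minimal sketch, not part of the original listing.
# The script above builds train_loader, optimizer and loss_fn but never uses
# them, so the function below sketches how the model could be trained before
# the inference calls are made. It ASSUMES that VariationalAutoEncoder's
# forward pass returns (x_reconstructed, mu, sigma); that interface lives in
# model.py and is not shown here, so adjust the unpacking if it differs.
def train(num_epochs=NUM_EPOCHS):
    model.train()
    for epoch in range(num_epochs):
        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for x, _ in loop:
            # Flatten each 1x28x28 image into a 784-dimensional vector
            x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
            x_reconstructed, mu, sigma = model(x)  # assumed forward signature

            # Reconstruction term (sum-reduced BCE) plus the closed-form KL
            # divergence between N(mu, sigma^2) and the standard normal prior
            reconstruction_loss = loss_fn(x_reconstructed, x)
            kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
            loss = reconstruction_loss + kl_div

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())

# train()  # would be called before the inference loop above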