# Repository path (from the Files browser):
#   Machine-Learning-Collection/ML/Pytorch/more_advanced/VAE/train.py
# Listed as: 62 lines, 1.8 KiB, Python

import torch
import torchvision.datasets as datasets # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
INPUT_DIM = 28 * 28  # one flattened 28x28 MNIST image = 784 values
H_DIM = 200  # hidden size handed to VariationalAutoEncoder
Z_DIM = 20  # presumably the latent size of VariationalAutoEncoder -- confirm in model.py
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Adam learning rate, a.k.a. the "Karpathy constant"

# ---------------------------------------------------------------------------
# Data, model, optimizer and loss
# ---------------------------------------------------------------------------
# MNIST training split as float tensors; fetched into ./dataset/ when missing.
dataset = datasets.MNIST(
    "dataset/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
model = model.to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
# Summed (not averaged) binary cross-entropy over every pixel in the batch.
loss_fn = nn.BCELoss(reduction="sum")
def inference(digit, num_examples=1):
    """Generate ``num_examples`` images of one MNIST digit and save them as PNGs.

    A reference image is picked by walking ``dataset`` in order: the first
    image labelled 0, then the first later image labelled 1, and so on up to
    ``digit``. The reference is encoded into ``(mu, sigma)``. Each sample
    draws ``z = mu + sigma * epsilon`` with ``epsilon ~ N(0, I)`` and decodes
    it with the VAE decoder. Sample ``k`` is written to
    ``generated_{digit}_ex{k}.png`` in the current working directory.

    Args:
        digit: Digit class to generate, 0 through 9.
        num_examples: Number of samples to draw and save.

    Raises:
        ValueError: If ``digit`` is outside 0..9.
        LookupError: If the sequential scan of ``dataset`` never reaches ``digit``.
    """
    if not 0 <= digit <= 9:
        # Reject early: a negative digit would otherwise be silently accepted.
        raise ValueError(f"digit must be in 0..9, got {digit!r}")

    # Same selection rule as before (0, then the next 1, ...), but stop as soon
    # as the requested digit is reached instead of collecting all ten.
    reference = None
    wanted = 0
    for x, y in dataset:
        if y == wanted:
            if wanted == digit:
                reference = x
                break
            wanted += 1
    if reference is None:
        raise LookupError(f"no reference image found for digit {digit}")

    # Dataset tensors are on the CPU; send the input wherever the model lives.
    device = next(model.parameters()).device
    with torch.no_grad():  # pure inference: no autograd graph needed
        mu, sigma = model.encode(reference.to(device).view(1, INPUT_DIM))
        for example in range(num_examples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")
def _train(num_epochs=NUM_EPOCHS):
    """Fit the module-level ``model`` on ``train_loader`` (negative-ELBO objective).

    Each batch is encoded to ``(mu, sigma)``. A latent ``z`` is sampled with
    the reparameterization trick and decoded. The batch is scored as
    ``loss_fn(reconstruction, x) + KL(N(mu, sigma^2) || N(0, I))`` and
    optimized with the module-level ``optimizer``. The model is left in
    eval mode on return.

    Args:
        num_epochs: Number of passes over ``train_loader``.
    """
    model.train()
    for epoch in range(num_epochs):
        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for x, _ in loop:
            # Flatten every image into a row of INPUT_DIM pixel values.
            x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)

            # NOTE(review): assumes encode() returns the standard deviation
            # (inference() samples with it the same way) and that decode()
            # ends in a sigmoid, as nn.BCELoss requires -- confirm in model.py.
            mu, sigma = model.encode(x)
            z = mu + sigma * torch.randn_like(sigma)  # reparameterization trick
            x_reconstructed = model.decode(z)

            reconstruction_loss = loss_fn(x_reconstructed, x)
            # Closed-form KL of a diagonal Gaussian from N(0, I); the clamp keeps
            # log() finite if sigma ever collapses to exactly zero.
            sigma_sq = sigma.pow(2)
            kl_div = -0.5 * torch.sum(
                1 + torch.log(sigma_sq.clamp_min(1e-8)) - mu.pow(2) - sigma_sq
            )
            loss = reconstruction_loss + kl_div

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())
    model.eval()


# Train first; otherwise decode() only maps noise through random weights.
_train()
# Generation is a handful of small decodes, so run it on the CPU, where the
# dataset's image tensors already live.
model.to("cpu")
for idx in range(10):
    inference(idx, num_examples=5)