# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# Synced 2026-02-20 13:50:41 +00:00 (62 lines, 1.8 KiB, Python).
import torch
|
|
import torchvision.datasets as datasets # Standard datasets
|
|
from tqdm import tqdm
|
|
from torch import nn, optim
|
|
from model import VariationalAutoEncoder
|
|
from torchvision import transforms
|
|
from torchvision.utils import save_image
|
|
from torch.utils.data import DataLoader
|
|
|
|
# Configuration
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model sizes: a flattened 28x28 image, the hidden width, and the third size
# argument handed to VariationalAutoEncoder (presumably the latent size).
INPUT_DIM = 28 * 28
H_DIM, Z_DIM = 200, 20

# Optimisation settings.
NUM_EPOCHS, BATCH_SIZE = 10, 32
LR_RATE = 3e-4  # Adam learning rate ("Karpathy constant")

# Dataset Loading
# MNIST training split converted to tensors; download=True lets torchvision
# fetch it into "dataset/" when it is not already there.
_to_tensor = transforms.ToTensor()
dataset = datasets.MNIST(
    "dataset/",
    train=True,
    transform=_to_tensor,
    download=True,
)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model on DEVICE, its Adam optimiser, and a summed binary cross-entropy loss.
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
model = model.to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")


def inference(digit, num_examples=1):
    """
    Generate ``num_examples`` images of one digit class and save them as PNGs.

    The first example in ``dataset`` whose label equals ``digit`` is passed
    through ``model.encode`` to obtain ``(mu, sigma)``. Each sample is then
    drawn as ``z = mu + sigma * epsilon`` with ``epsilon`` from a standard
    normal, run through ``model.decode``, reshaped to ``(-1, 1, 28, 28)`` and
    written to ``generated_{digit}_ex{example}.png``.

    Args:
        digit: Label of the class to generate (compared with ``==`` against
            the dataset labels).
        num_examples: How many samples to draw and save.

    Raises:
        ValueError: If ``dataset`` contains no example labelled ``digit``.
    """
    # Only the requested class is needed, so stop at its first occurrence
    # instead of scanning for (and encoding) all ten classes on every call.
    # A label that never occurs now raises a clear error instead of an
    # IndexError or a silent negative-index wrap-around.
    image = next((x for x, y in dataset if y == digit), None)
    if image is None:
        raise ValueError(f"no example with label {digit!r} found in dataset")

    # Nothing here needs gradients; keeping decode() under no_grad as well
    # avoids building an autograd graph for every generated sample.
    # NOTE(review): the model is not switched to eval() here; confirm
    # VariationalAutoEncoder has no train/eval-dependent layers.
    with torch.no_grad():
        # ToTensor yields CPU tensors while the model was moved to DEVICE, so
        # move the input over before encoding (otherwise this fails on CUDA).
        mu, sigma = model.encode(image.view(1, INPUT_DIM).to(DEVICE))
        for example in range(num_examples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")


# Generate five samples for every digit class 0-9.
# NOTE(review): no training step or checkpoint load is visible in this file,
# so `model` appears to still hold its freshly initialised weights at this
# point and the saved images would likely not resemble digits; confirm the
# trained weights are produced or loaded before this runs.
for digit in range(10):
    inference(digit, 5)