mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
vae
This commit is contained in:
86
ML/Pytorch/more_advanced/VAE/train.py
Normal file
86
ML/Pytorch/more_advanced/VAE/train.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torchvision.datasets as datasets # Standard datasets
|
||||
from tqdm import tqdm
|
||||
from torch import nn, optim
|
||||
from model import VariationalAutoEncoder
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import save_image
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784  # 28x28 MNIST images, flattened to a vector
H_DIM = 200      # hidden layer width of encoder/decoder
Z_DIM = 20       # latent space dimensionality
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
# MNIST train split, converted to [0, 1] tensors (required for BCE loss below).
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
# reduction="sum" sums BCE over pixels and batch, matching the summed KL term
# in the training loop (both terms of the ELBO on the same scale).
loss_fn = nn.BCELoss(reduction="sum")
||||
# Start Training
# Minimizes the negative ELBO: summed reconstruction BCE plus the analytic
# KL divergence between the approximate posterior N(mu, sigma^2) and the
# standard-normal prior.
for epoch in range(NUM_EPOCHS):
    # total= lets tqdm render a proper progress fraction instead of a bare counter.
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (x, _) in loop:
        # Forward pass: flatten each 28x28 image to a 784-dim vector.
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # Reconstruction term of the ELBO (summed over pixels and batch).
        reconstruction_loss = loss_fn(x_reconstructed, x)
        # Analytic KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * sum(1 + log s^2 - mu^2 - s^2).
        # The 1/2 factor was missing in the original, which silently doubled
        # the KL weight relative to the reconstruction term.
        kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # Backprop on the total (negative-ELBO) loss.
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())
||||
# Generation below runs on CPU; move the trained model there once.
model = model.to("cpu")


def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.

    Specifically we extract an example image of each digit class from the
    dataset, encode it to its latent (mu, sigma) representation, and then
    sample z ~ N(mu, sigma^2) for the requested digit. Each sample is run
    through the decoder and written to disk as
    ``generated_{digit}_ex{example}.png``.

    Args:
        digit: digit class (0-9) to generate.
        num_examples: number of samples to generate for that class.
    """
    # Grab the first exemplar of each digit class in a single pass.
    # The original required digits to appear in strictly ascending order
    # (a 0 first, then a 1 appearing after that 0, ...), which could scan
    # far more of the dataset than necessary.
    exemplars = {}
    for x, y in dataset:
        if y not in exemplars:
            exemplars[y] = x
        if len(exemplars) == 10:
            break

    # No gradients are ever backpropagated here; no_grad avoids building
    # an autograd graph for both the encode and decode passes (the decode
    # pass was previously tracked for no reason).
    with torch.no_grad():
        encodings_digit = []
        for d in range(10):
            mu, sigma = model.encode(exemplars[d].view(1, 784))
            encodings_digit.append((mu, sigma))

        mu, sigma = encodings_digit[digit]
        for example in range(num_examples):
            # Reparameterization: z = mu + sigma * eps, eps ~ N(0, I).
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z)
            out = out.view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")
|
||||
# Generate five sample images for every digit class 0-9.
for digit_class in range(10):
    inference(digit_class, num_examples=5)
Reference in New Issue
Block a user