Author: dino
Date: 2022-09-13 13:04:49 +02:00
parent ac5dcd03a4
commit ae581a64e6
3 changed files with 140 additions and 9 deletions


@@ -0,0 +1,45 @@
import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow back through mu and sigma
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


if __name__ == "__main__":
    x = torch.randn(4, 28 * 28)  # batch of 4 flattened 28x28 images
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)  # torch.Size([4, 784])
    print(mu.shape)  # torch.Size([4, 20])
    print(sigma.shape)  # torch.Size([4, 20])
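Since the latent prior is a standard normal, a trained model can also be used as a pure generator by decoding random latents. A minimal sketch (not part of the commit; assumes `vae` has already been trained on 28x28 images):

# Sketch: using the decoder as a generator (assumes a trained `vae`).
# New images come from decoding samples of the N(0, I) prior over z.
with torch.no_grad():
    z = torch.randn(16, 20)  # 16 samples from the z_dim=20 prior
    generated = vae.decode(z)  # (16, 784), values in (0, 1) from the sigmoid
    images = generated.view(-1, 1, 28, 28)  # reshape to 28x28 grayscale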


@@ -0,0 +1,86 @@
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (x, _) in loop:
        # Forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # Compute loss: reconstruction term plus the closed-form KL divergence
        # between N(mu, sigma^2) and the N(0, I) prior (up to a factor of 1/2)
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # Backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

model = model.to("cpu")


def inference(digit, num_examples=1):
    """
    Generates num_examples images of a particular digit.
    Specifically, we extract one example of each digit, encode it to get
    its mu, sigma representation, and sample from that distribution.
    Each sample is then run through the decoder part of the VAE to
    generate a new example.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")


for idx in range(10):
    inference(idx, num_examples=5)
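For reference, the `kl_div` line above is the standard closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior; the code drops the usual factor of 1/2, which only rescales the KL term relative to the reconstruction loss:

% KL divergence of N(mu, diag(sigma^2)) from the N(0, I) prior, summed over latent dims
\[
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^2) \,\middle\|\, \mathcal{N}(0, 1)\right)
  = -\frac{1}{2} \sum_{j=1}^{Z} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\]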


@@ -23,10 +23,10 @@ class SelfAttention(nn.Module):
             self.head_dim * heads == embed_size
         ), "Embedding size needs to be divisible by heads"
 
-        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
+        self.values = nn.Linear(embed_size, embed_size)
+        self.keys = nn.Linear(embed_size, embed_size)
+        self.queries = nn.Linear(embed_size, embed_size)
+        self.fc_out = nn.Linear(embed_size, embed_size)
 
     def forward(self, values, keys, query, mask):
         # Get number of training examples
@@ -34,14 +34,14 @@ class SelfAttention(nn.Module):
         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
 
+        values = self.values(values)  # (N, value_len, embed_size)
+        keys = self.keys(keys)  # (N, key_len, embed_size)
+        queries = self.queries(query)  # (N, query_len, embed_size)
+
         # Split the embedding into self.heads different pieces
         values = values.reshape(N, value_len, self.heads, self.head_dim)
         keys = keys.reshape(N, key_len, self.heads, self.head_dim)
-        query = query.reshape(N, query_len, self.heads, self.head_dim)
-
-        values = self.values(values)  # (N, value_len, heads, head_dim)
-        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
-        queries = self.queries(query)  # (N, query_len, heads, heads_dim)
+        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
 
         # Einsum does matrix mult. for query*keys for each training example
         # with every other training example, don't be confused by einsum
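The hunk cuts off at the einsum comment. For context, a sketch of the multi-head attention step that such reshapes typically feed into; this is an illustration of the standard pattern, not lines from this diff (the sqrt(head_dim) scaling is one common convention):

# Sketch of the attention computation the comment above refers to (illustrative).
# queries: (N, query_len, heads, head_dim), keys: (N, key_len, heads, head_dim)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, query_len, key_len)

if mask is not None:
    energy = energy.masked_fill(mask == 0, float("-1e20"))  # block masked positions

attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)  # normalize over key_len
out = torch.einsum("nhql,nlhd->nqhd", [attention, values])  # weighted sum of values
out = out.reshape(N, query_len, self.heads * self.head_dim)  # concatenate heads
out = self.fc_out(out)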