mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
vae
45  ML/Pytorch/more_advanced/VAE/model.py  Normal file
@@ -0,0 +1,45 @@
import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)  # unconstrained linear output, used as sigma

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))  # pixel values in (0, 1)

    def forward(self, x):
        mu, sigma = self.encode(x)
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


if __name__ == "__main__":
    x = torch.randn(4, 28 * 28)  # batch of 4 flattened 28x28 images
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)  # torch.Size([4, 784])
    print(mu.shape)               # torch.Size([4, 20])
    print(sigma.shape)            # torch.Size([4, 20])
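A quick sanity check on the reparameterization step in forward (a minimal sketch, not part of the commit): sampling epsilon outside the graph and forming z = mu + sigma * epsilon keeps the sampling step differentiable, so gradients flow back into both encoder heads.

import torch

# Minimal sketch: gradients reach mu and sigma through the reparameterized sample.
mu = torch.zeros(4, 20, requires_grad=True)
sigma = torch.ones(4, 20, requires_grad=True)

epsilon = torch.randn_like(sigma)  # noise sampled outside the autograd graph
z = mu + sigma * epsilon           # z is a differentiable function of mu and sigma

z.sum().backward()
print(mu.grad is not None, sigma.grad is not None)  # True True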
86  ML/Pytorch/more_advanced/VAE/train.py  Normal file
@@ -0,0 +1,86 @@
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")  # per-pixel Bernoulli reconstruction loss

# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # Forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)  # flatten to (batch, 784)
        x_reconstructed, mu, sigma = model(x)

        # Compute loss: reconstruction term plus closed-form
        # KL(N(mu, sigma^2) || N(0, I)) = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # Backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())


model = model.to("cpu")


def inference(digit, num_examples=1):
    """
    Generates num_examples images of a particular digit.
    Specifically, we extract an example of each digit,
    then once we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    # Grab one example of each digit 0-9 (in order) from the dataset
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    # Encode each digit to its (mu, sigma) latent representation
    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
            encodings_digit.append((mu, sigma))

    # Sample around the chosen digit's encoding and decode
    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")


for idx in range(10):
    inference(idx, num_examples=5)
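The encoder here predicts sigma directly, so torch.log(sigma.pow(2)) is what makes the KL term computable. A common alternative in VAE implementations (not what this commit does) is to have the encoder predict log-variance instead, which is well-defined for any real-valued network output. A minimal sketch, with the log-variance head as an assumed variant, showing the two forms of the KL term agree:

import torch

# Hypothetical alternative parameterization: encoder outputs log_var = log(sigma^2).
mu = torch.randn(32, 20)
log_var = torch.randn(32, 20)       # unconstrained, unlike a positive sigma
sigma = torch.exp(0.5 * log_var)

kl_from_logvar = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
kl_from_sigma = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
print(torch.allclose(kl_from_logvar, kl_from_sigma))  # True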
@@ -23,10 +23,10 @@ class SelfAttention(nn.Module):
             self.head_dim * heads == embed_size
         ), "Embedding size needs to be divisible by heads"
 
-        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
+        self.values = nn.Linear(embed_size, embed_size)
+        self.keys = nn.Linear(embed_size, embed_size)
+        self.queries = nn.Linear(embed_size, embed_size)
+        self.fc_out = nn.Linear(embed_size, embed_size)
 
     def forward(self, values, keys, query, mask):
         # Get number of training examples
@@ -34,14 +34,14 @@ class SelfAttention(nn.Module):
 
         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
 
+        values = self.values(values)  # (N, value_len, embed_size)
+        keys = self.keys(keys)  # (N, key_len, embed_size)
+        queries = self.queries(query)  # (N, query_len, embed_size)
+
         # Split the embedding into self.heads different pieces
         values = values.reshape(N, value_len, self.heads, self.head_dim)
         keys = keys.reshape(N, key_len, self.heads, self.head_dim)
-        query = query.reshape(N, query_len, self.heads, self.head_dim)
-
-        values = self.values(values)  # (N, value_len, heads, head_dim)
-        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
-        queries = self.queries(query)  # (N, query_len, heads, heads_dim)
+        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
 
         # Einsum does matrix mult. for query*keys for each training example
         # with every other training example, don't be confused by einsum