# Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git
# Synced 2026-02-21 19:27:58 +00:00
# 46 lines, 1.2 KiB, Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder projects an input to a hidden representation and from there to
    two ``z_dim``-sized vectors, ``mu`` and ``sigma``. The decoder projects a
    latent vector back to ``input_dim`` values and applies a sigmoid.

    Args:
        input_dim: Size of the last dimension of the tensors fed to the model.
        h_dim: Width of the hidden layer shared by encoder and decoder.
        z_dim: Size of the latent vector.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()

        # Encoder: input -> hidden -> (mu, sigma).
        # Layers are created in a fixed order so seeded initialisation and
        # state_dict key order remain stable.
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder: latent -> hidden -> reconstruction.
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        # Single activation module reused by both halves of the network.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return the ``(mu, sigma)`` pair produced by the encoder for ``x``.

        Both outputs are raw linear projections of the same hidden activation;
        no positivity constraint is applied to ``sigma`` here.
        """
        hidden = self.img_2hid(x)
        hidden = self.relu(hidden)
        mu = self.hid_2mu(hidden)
        sigma = self.hid_2sigma(hidden)
        return mu, sigma

    def decode(self, z):
        """Map a latent tensor ``z`` to a sigmoid-activated reconstruction."""
        hidden = self.z_2hid(z)
        hidden = self.relu(hidden)
        logits = self.hid_2img(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, draw a latent sample, and decode it.

        Returns:
            A tuple ``(x_reconstructed, mu, sigma)``.
        """
        mu, sigma = self.encode(x)

        # Reparameterisation: the randomness comes from standard-normal noise
        # with the same shape as sigma, scaled by sigma and shifted by mu.
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise

        reconstruction = self.decode(latent)
        return reconstruction, mu, sigma
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push a random batch of 4 flattened 28x28 inputs through the
    # model and print the shape of each returned tensor
    # (reconstruction, mu, sigma — in that order).
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(input_dim=784)
    for output in model(batch):
        print(output.shape)
|
|
|
|
|
|
|