Files
Machine-Learning-Collection/ML/Pytorch/more_advanced/VAE/model.py

46 lines
1.2 KiB
Python
Raw Normal View History

2022-09-13 13:04:49 +02:00
import torch
from torch import nn
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    Encoder: ``input_dim -> h_dim`` (ReLU), then two linear heads mapping the
    hidden activation to ``mu`` and ``sigma`` (each ``z_dim`` wide).
    Decoder: ``z_dim -> h_dim`` (ReLU) -> ``input_dim``, squashed by a sigmoid
    so reconstructions lie in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input (and of the
            reconstruction). Presumably a flattened image; the module's demo
            uses 28 * 28 -- confirm against callers.
        h_dim: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # NOTE: attribute names and their registration order determine the
        # state_dict keys, the parameters() order and the RNG draws used for
        # weight init; keep them stable so saved checkpoints still load.
        # -- encoder --
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # -- decoder --
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        # Parameter-free activation shared by both halves.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map ``x`` to the parameters of the latent Gaussian.

        Returns:
            ``(mu, sigma)``, each with last dimension ``z_dim``.

        NOTE(review): ``hid_2sigma`` is an unconstrained linear layer, so
        ``sigma`` may come out negative or zero. Sampling in ``forward`` is
        still valid in distribution (epsilon is symmetric), but any loss that
        takes ``log(sigma**2)`` will hit ``-inf`` at zero; predicting a
        log-variance would be the more robust parameterization.
        """
        hidden = self.img_2hid(x)
        hidden = self.relu(hidden)
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent codes ``z`` back to input space; values lie in [0, 1]."""
        hidden = self.relu(self.z_2hid(z))
        logits = self.hid_2img(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent with the reparameterization trick, decode.

        Returns:
            ``(x_reconstructed, mu, sigma)``.
        """
        mu, sigma = self.encode(x)
        # z = mu + sigma * eps with eps ~ N(0, I): the randomness lives in eps,
        # so the sample stays differentiable w.r.t. mu and sigma.
        noise = torch.randn_like(sigma)
        z = mu + sigma * noise
        return self.decode(z), mu, sigma
if __name__ == "__main__":
    # Smoke test: push a fake batch of four flattened 28x28 inputs through the
    # model and print the shapes of the reconstruction, mu and sigma.
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(input_dim=28 * 28)
    reconstruction, mu, sigma = model(batch)
    for tensor in (reconstruction, mu, sigma):
        print(tensor.shape)