Hi there,
I am trying to build a toy variational network that, sampling from a learned Gaussian posterior distribution, approximate another unknown distribution through a series of fully connected layers. In the code below, the target distribution is defined as a N(4,3).
I am trying to do that with variational inference. However, my model does not learn the target distribution at all, in fact it does learn to ignore posterior in the input, predicting always a constant value for different samples of the posterior. I don’t want to use Maximum Likelihood because the target distribution is pretended to be unknown.
Can anyone help me understand (1) why is that happening and (2) how to do this correctly? Reproducible code below.
import numpy as np
import matplotlib.pyplot as plt
import torch
# Define the parameters of the target Gaussian distribution we want to approximate
# This distribution is meant to be unknown (e.g. an image)
MU = 4
SIGMA = 3
# Define a neural network that learns the parameters of a posterior (mu and sigma), samples
# from that distribution, and produces a scalar output after a series of projections.
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
# Prior for the KLD
self.prior = torch.distributions.Normal(0, 1);
# Gaussian posterior distribution parameters to be learned by the model
self.mu = torch.nn.Parameter(torch.normal(0, 0.1, size=(1,)), requires_grad=True)
self.var = torch.nn.Parameter(torch.ones(1), requires_grad=True)
self.posterior = torch.distributions.Normal(self.mu, self.var)
# Multi-Layer Perceptron
self.projection_1 = torch.nn.Linear(in_features=1, out_features=10)
self.projection_2 = torch.nn.Linear(in_features=10, out_features=10)
self.projection_out = torch.nn.Linear(in_features=10, out_features=1)
def forward(self):
# Sample from the posterior
z = self.posterior.sample()
# MLP Projection
h = torch.nn.ReLU()(self.projection_1(z))
h = torch.nn.ReLU()(self.projection_2(h))
# Output projection
o = self.projection_out(h)
return o
net = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
# Do 50,000 steps of gradient descent
for i in range(50000):
net.zero_grad() # Reset the gradient
# Forward
y_hat = net()
# Take a sample from the target distribution
y = torch.normal(MU, SIGMA, size=(1,))
# Calculate the KLD
kld = torch.distributions.kl_divergence(net.prior, net.posterior)
# Calculate Loss
loss = criterion(y_hat, y) + kld/100
# Backward
loss.backward()
# Training step
optimizer.step()
# Now let's inspect the learned distribution
outputs = [net().cpu().detach().numpy().item() for _ in range(1000)]
plt.hist(outputs)