I am trying to build a toy variational network that, sampling from a learned Gaussian posterior distribution, approximate another unknown distribution through a series of fully connected layers. In the code below, the target distribution is defined as a N(4,3).
I am trying to do that with variational inference. However, my model does not learn the target distribution at all, in fact it does learn to ignore posterior in the input, predicting always a constant value for different samples of the posterior. I don’t want to use Maximum Likelihood because the target distribution is pretended to be unknown.
Can anyone help me understand (1) why is that happening and (2) how to do this correctly? Reproducible code below.
import numpy as np import matplotlib.pyplot as plt import torch # Define the parameters of the target Gaussian distribution we want to approximate # This distribution is meant to be unknown (e.g. an image) MU = 4 SIGMA = 3 # Define a neural network that learns the parameters of a posterior (mu and sigma), samples # from that distribution, and produces a scalar output after a series of projections. class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() # Prior for the KLD self.prior = torch.distributions.Normal(0, 1); # Gaussian posterior distribution parameters to be learned by the model self.mu = torch.nn.Parameter(torch.normal(0, 0.1, size=(1,)), requires_grad=True) self.var = torch.nn.Parameter(torch.ones(1), requires_grad=True) self.posterior = torch.distributions.Normal(self.mu, self.var) # Multi-Layer Perceptron self.projection_1 = torch.nn.Linear(in_features=1, out_features=10) self.projection_2 = torch.nn.Linear(in_features=10, out_features=10) self.projection_out = torch.nn.Linear(in_features=10, out_features=1) def forward(self): # Sample from the posterior z = self.posterior.sample() # MLP Projection h = torch.nn.ReLU()(self.projection_1(z)) h = torch.nn.ReLU()(self.projection_2(h)) # Output projection o = self.projection_out(h) return o net = Net() criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.01) # Do 50,000 steps of gradient descent for i in range(50000): net.zero_grad() # Reset the gradient # Forward y_hat = net() # Take a sample from the target distribution y = torch.normal(MU, SIGMA, size=(1,)) # Calculate the KLD kld = torch.distributions.kl_divergence(net.prior, net.posterior) # Calculate Loss loss = criterion(y_hat, y) + kld/100 # Backward loss.backward() # Training step optimizer.step() # Now let's inspect the learned distribution outputs = [net().cpu().detach().numpy().item() for _ in range(1000)] plt.hist(outputs)