Simple variational net with torch.distributions

Hi there,

I am trying to build a toy variational network that samples from a learned Gaussian posterior distribution and approximates another, unknown distribution through a series of fully connected layers. In the code below, the target distribution is defined as N(4, 3).
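To be clear, "unknown" here only means the model never sees the parameters of the target; it only gets samples from it, one per training step, roughly like this:

import torch

# the target is only ever observed through samples like this one
y = torch.normal(4.0, 3.0, size=(1,))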

I am trying to do this with variational inference. However, my model does not learn the target distribution at all; instead, it learns to ignore the posterior sample at the input and always predicts a constant value, no matter which sample is drawn from the posterior. I don't want to use maximum likelihood because the target distribution is supposed to be unknown.
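To illustrate what I mean by "a constant value": if I draw many outputs from the trained network (full code below), their spread is essentially zero. A rough check along these lines shows it:

# after training: draw many outputs and look at their spread
with torch.no_grad():
    outs = torch.stack([net() for _ in range(1000)])
print(outs.mean().item(), outs.std().item())  # the std comes out close to 0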

Can anyone help me understand (1) why this is happening and (2) how to do this correctly? Reproducible code below.

import numpy as np
import matplotlib.pyplot as plt
import torch

# Define the parameters of the target Gaussian distribution we want to approximate
# This distribution is meant to be unknown (e.g. an image)
MU = 4
SIGMA = 3

# Define a neural network that learns the parameters of a posterior (mu and sigma), samples
# from that distribution, and produces a scalar output after a series of projections.

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Prior for the KLD
        self.prior = torch.distributions.Normal(0, 1)
        # Gaussian posterior distribution parameters to be learned by the model
        self.mu = torch.nn.Parameter(torch.normal(0, 0.1, size=(1,)), requires_grad=True)
        self.var = torch.nn.Parameter(torch.ones(1), requires_grad=True)
        self.posterior = torch.distributions.Normal(self.mu, self.var)
        # Multi-Layer Perceptron 
        self.projection_1 = torch.nn.Linear(in_features=1, out_features=10)
        self.projection_2 = torch.nn.Linear(in_features=10, out_features=10)
        self.projection_out = torch.nn.Linear(in_features=10, out_features=1)

    def forward(self):
        # Sample from the posterior
        z = self.posterior.sample()
        # MLP Projection
        h = torch.relu(self.projection_1(z))
        h = torch.relu(self.projection_2(h))
        # Output projection
        o = self.projection_out(h)
        return o

net = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Do 50,000 steps of gradient descent

for i in range(50000):
    net.zero_grad()  # Reset the gradients
    # Forward pass: sample from the posterior and project to a scalar
    y_hat = net()
    # Take a single sample from the target distribution
    y = torch.normal(MU, SIGMA, size=(1,))
    # KL divergence between the prior and the learned posterior
    kld = torch.distributions.kl_divergence(net.prior, net.posterior)
    # Total loss: MSE plus the down-weighted KLD
    loss = criterion(y_hat, y) + kld / 100
    # Backward pass
    loss.backward()
    # Parameter update
    optimizer.step()

# Now let's inspect the learned distribution
outputs = [net().cpu().detach().numpy().item() for _ in range(1000)]
plt.hist(outputs)
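
For reference, here is what samples drawn directly from the target look like; the histogram of the outputs above basically collapses onto a single bar instead:

# For comparison: histogram of samples from the actual target N(4, 3)
target_samples = torch.normal(MU, SIGMA, size=(1000,)).numpy()
plt.figure()
plt.hist(target_samples)
plt.show()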