I’m trying to get a mixture density network to approximate multivariate distributions. As a pedagogical toy example, I’m considering a noisy linear distribution.
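For concreteness, the data I have in mind looks roughly like this (the slope and noise schedule below are illustrative; DS is the Dataset class used in the training loops further down):

import torch
from torch.utils.data import Dataset

class DS(Dataset):
    """Noisy linear data: y = a*x + eps, with noise growing with x (illustrative values)."""
    def __init__(self, n=1024):
        self.x = 10 * torch.rand(n, 1)                      # x uniform in [0, 10]
        noise = torch.randn(n, 1) * (0.1 + 0.3 * self.x)    # heteroscedastic noise
        self.y = 0.5 * self.x + noise

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]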
As a baseline, I’m fitting this with a basic model:
baseline = nn.Sequential(nn.Linear(1,32), nn.ReLU(), nn.Linear(32,1))
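For completeness, the baseline is trained as a plain regression on the mean, roughly like this (optimizer and hyperparameters here are illustrative):

import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

opt = optim.Adam(baseline.parameters(), lr=1e-2)   # illustrative optimizer / learning rate
loader = DataLoader(DS(), batch_size=64, shuffle=True)

for epoch in range(50):
    for xx, yy in loader:
        loss = F.mse_loss(baseline(xx), yy)        # plain MSE on the conditional mean
        opt.zero_grad()
        loss.backward()
        opt.step()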
This gives me a reasonable baseline fit.
Now, I’m creating a mixture model as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureModel(nn.Module):
    def __init__(self, k=5):
        super().__init__()
        # shared trunk
        self.base = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU())
        # one head per mixture parameter, k components each
        self.means = nn.Linear(128, k)
        self.stds = nn.Linear(128, k)
        self.weights = nn.Linear(128, k)

    def forward(self, x):
        latent = self.base(x)
        means = self.means(latent)
        stds = F.elu(self.stds(latent)) + 1            # keep the scales positive
        w = F.softmax(self.weights(latent), dim=1)     # mixture weights sum to 1
        return means, stds, w
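As a quick sanity check (just probing the forward pass; the batch size here is arbitrary), the three heads come out with the expected shapes:

model = MixtureModel(k=5)
x = torch.rand(64, 1)
means, stds, w = model(x)
print(means.shape, stds.shape, w.shape)   # torch.Size([64, 5]) each
print(w.sum(dim=1)[:3])                   # weights sum to 1 per sample
print((stds > 0).all())                   # scales are strictly positive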
Training goes like this, maximizing the log-likelihood of the observed data:
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.distributions import Normal, Categorical, MixtureSameFamily

model = MixtureModel()
adam = optim.Adam(model.parameters(), lr=1e-2)
loader = DataLoader(DS(), batch_size=64, shuffle=True)
epochs = 50
losses = []

for epoch in range(epochs):
    epoch_loss = []
    for xx, yy in loader:
        means, sigma, w = model(xx)
        comp = Normal(means, sigma)          # k Gaussians per sample
        mix = Categorical(w)                 # mixture weights per sample
        gmm = MixtureSameFamily(mix, comp)
        likelihood = gmm.log_prob(yy)
        loss = -likelihood.mean()            # negative log-likelihood
        adam.zero_grad()
        loss.backward()
        adam.step()
        epoch_loss.append(loss.item())
    losses.append(np.mean(epoch_loss))
However, after a few epochs of decreasing loss, the training plateaus and the results are disappointing. Specifically, I was expecting the predicted variance to grow with x, but the network predicts constant means, weights and variances for all x.
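For reference, this is roughly how I inspect the predictions over the input range (plotting details are approximate):

import matplotlib.pyplot as plt

with torch.no_grad():
    xs = torch.linspace(0, 10, 200).unsqueeze(1)   # x range assumed to match the data
    means, stds, w = model(xs)

xs_np = xs.squeeze(1).numpy()
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
axes[0].plot(xs_np, means.numpy()); axes[0].set_title("component means")
axes[1].plot(xs_np, stds.numpy());  axes[1].set_title("component stds")
axes[2].plot(xs_np, w.numpy());     axes[2].set_title("mixture weights")
plt.show()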
I observe the same behaviour when I increase the number of Gaussians.
Could someone point me in the right direction?
Thanks a lot ((: