Learn a normal distribution with MultivariateNormal(loc, covariance_matrix)

Hi everyone,

I am trying to learn a multivariate normal distribution. To simplify things, let’s say the input is fixed and the labels are some 3D points (x, y, z).

My approach was to create an NN that outputs 9 values, corresponding to 3 means, 3 variances and 3 covariances. Then, at each step, call rsample() on a

MultivariateNormal(loc=m, covariance_matrix=c)

where c is the symmetric covariance matrix built from the 6 variances and covariances. Then I get the loss by comparing the rsampled 3D point to a point from the labels.

However, I sometimes get an error at the creation of the distribution, because the covariance matrix is not always positive definite.
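
For reference, this is roughly what my current code looks like (a minimal sketch with a random stand-in for the network output; the real code differs in the details):

import torch
from torch.distributions import MultivariateNormal

outputs = torch.randn(8, 9)             # stand-in for the NN output: [batch, 9]
m, var, cov = outputs.split(3, dim=-1)  # 3 means, 3 variances, 3 covariances

# Symmetric 3x3 covariance matrix built from the 6 variance/covariance values
c = torch.stack([
    var[:, 0], cov[:, 0], cov[:, 1],
    cov[:, 0], var[:, 1], cov[:, 2],
    cov[:, 1], cov[:, 2], var[:, 2]
], dim=-1).view(-1, 3, 3)

# This is where it sometimes fails: nothing constrains c to be
# positive definite (the predicted variances can even be negative)
dist = MultivariateNormal(loc=m, covariance_matrix=c)
point = dist.rsample()                  # [batch, 3]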

  1. Is this the right approach? Or is there a proper way to do this?
  2. If it is, how can I ensure that the NN produces valid distribution parameters?
  3. Should I use one of the other arguments (precision_matrix / scale_tril) instead of covariance_matrix when creating the distribution? I don’t understand their mathematical meanings…

Thank you,
PL

Ended up using the 6 parameters to generate the scale_tril instead of the covariance_matrix.
It worked really well.

@phan_phan can you explain your solution in more detail? Do you do a regression output of 6 parameters and just send it to scale_tril?
How exactly do you generate scale_tril?


Hi,
Yes, I eventually did a regression over the distribution parameters loc and scale_tril.
scale_tril is the lower-triangular Cholesky factor L of the covariance matrix (C = L Lᵀ): any lower-triangular matrix with a strictly positive diagonal is a valid factor, so unlike with covariance_matrix there is no positive-definiteness constraint to enforce.
So my last NN layer outputs a tensor of size [batch, 9].

However, a few things to note:

  • The diagonal of scale_tril must be strictly positive, so I used 1 + ELU (ELU outputs values in (-1, ∞), so 1 + ELU is always strictly positive):
loc, tril, diag = outputs.split(3, dim=-1)
diag = 1 + nn.functional.elu(diag)  # strictly positive diagonal entries
# Each of loc, tril, diag is a tensor of size [batch, 3]
  • I generated scale_tril the dirty way; maybe there’s a better one (see the sketch after this list):
z = torch.zeros(size=[loc.size(0)], device=device)  # one zero per batch element
scale_tril = torch.stack([
    diag[:, 0], z         , z,
    tril[:, 0], diag[:, 1], z,
    tril[:, 1], tril[:, 2], diag[:, 2]
], dim=-1).view(-1, 3, 3)
# scale_tril is a tensor of size [batch, 3, 3]
dist = MultivariateNormal(loc=loc, scale_tril=scale_tril)
  • In order to train, instead of minimizing a criterion between a distribution sample and the label (like I first thought), it is better to minimize the negative log-likelihood of the label under the distribution, using the .log_prob() method:
loss = - dist.log_prob(labels).mean()
loss.backward()
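
For completeness, here is a minimal self-contained sketch of how it all fits together (the tiny linear model, sizes and random data are just placeholders). It also shows a tidier way to build scale_tril, using torch.diag_embed and torch.tril_indices instead of the stack above:

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

model = nn.Linear(16, 9)  # stand-in for the real network: 16 features in, 9 parameters out
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

inputs = torch.randn(32, 16)  # dummy batch
labels = torch.randn(32, 3)   # target 3D points

outputs = model(inputs)
loc, tril, diag = outputs.split(3, dim=-1)
diag = 1 + nn.functional.elu(diag)  # strictly positive diagonal

# Diagonal entries via diag_embed, the 3 strictly-lower entries via tril_indices
scale_tril = torch.diag_embed(diag)               # [batch, 3, 3]
rows, cols = torch.tril_indices(3, 3, offset=-1)  # positions (1,0), (2,0), (2,1)
scale_tril[:, rows, cols] = tril

dist = MultivariateNormal(loc=loc, scale_tril=scale_tril)
loss = -dist.log_prob(labels).mean()  # negative log-likelihood
optimizer.zero_grad()
loss.backward()
optimizer.step()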

Thank you so much, I will try it!