I am trying to learn the parameters of a multivariate normal distribution. To simplify things, let's say the input is fixed and the labels are 3D points (x, y, z).
My approach was to create a NN that outputs 9 values: 3 means, 3 variances and 3 covariances. Then, at each step, I draw an rsample from a multivariate normal distribution parameterized by the 3 means and c, where c is the symmetric covariance matrix assembled from the 6 variances and covariances. I then compute the loss by comparing the "rsampled" 3D point to a point from the labels.
However, I sometimes get an error when creating the distribution, because the covariance matrix is not always positive definite.
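To illustrate the failure mode, here is a minimal sketch (not the original code) that fills a symmetric 3×3 matrix with 6 arbitrary values, as a "variances and covariances" head would. The numbers are made up. Symmetry alone is not enough: a covariance that is too large relative to the variances (or a negative variance) gives a matrix with a negative eigenvalue, and `torch.distributions.MultivariateNormal` rejects it.

```python
import torch
from torch.distributions import MultivariateNormal

# Hypothetical raw network outputs: 3 means, 3 "variances", 3 "covariances"
means = torch.zeros(3)
variances = torch.tensor([1.0, 1.0, 1.0])
covariances = torch.tensor([2.0, 0.0, 0.0])  # |cov_xy| > sqrt(var_x * var_y)

# Build the symmetric matrix directly from the raw values
c = torch.diag(variances)
c[0, 1] = c[1, 0] = covariances[0]
c[0, 2] = c[2, 0] = covariances[1]
c[1, 2] = c[2, 1] = covariances[2]

print(torch.linalg.eigvalsh(c))  # one eigenvalue is negative

try:
    dist = MultivariateNormal(means, covariance_matrix=c)
    failed = False
except (ValueError, RuntimeError) as err:  # argument validation or Cholesky failure
    failed = True
    print("invalid covariance:", type(err).__name__)
```

During training this happens only intermittently, which matches the "sometimes": nothing in the loss keeps the raw outputs inside the positive-definite region.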
- Is this the right approach? Or is there a proper way to do this?
- If it is, how can I make sure the NN produces valid distribution parameters?
- Should I use one of the other arguments (`precision_matrix` / `scale_tril`) instead of `covariance_matrix` when creating the distribution? I don't understand their mathematical meaning.
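For reference, the three arguments describe the same object in different forms: `covariance_matrix` is Σ itself, `precision_matrix` is its inverse Σ⁻¹, and `scale_tril` is a lower-triangular L with a positive diagonal such that Σ = L Lᵀ (the Cholesky factor of Σ). The sketch below uses an arbitrarily chosen L, builds the same distribution all three ways, and checks that they assign the same log-density to a point.

```python
import torch
from torch.distributions import MultivariateNormal

# Arbitrary lower-triangular factor with a strictly positive diagonal
L = torch.tensor([[1.0, 0.0, 0.0],
                  [0.5, 2.0, 0.0],
                  [-0.3, 0.4, 0.7]])
loc = torch.zeros(3)

cov = L @ L.T                 # covariance_matrix = L Lᵀ
prec = torch.linalg.inv(cov)  # precision_matrix = inverse of the covariance

d_tril = MultivariateNormal(loc, scale_tril=L)
d_cov = MultivariateNormal(loc, covariance_matrix=cov)
d_prec = MultivariateNormal(loc, precision_matrix=prec)

x = torch.tensor([0.2, -1.0, 0.5])
lp = [d.log_prob(x) for d in (d_tril, d_cov, d_prec)]
print(lp)  # three values, equal up to floating-point round-off
```

Internally the distribution works with the Cholesky factor either way (it computes it from `covariance_matrix` or `precision_matrix` if you pass those), so passing `scale_tril` directly also skips a factorization.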
I ended up using the 6 parameters to generate the `scale_tril` instead of the `covariance_matrix`. It worked really well.
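Why this always yields a valid distribution, as a short derivation: for any lower-triangular L whose diagonal entries are strictly positive, L is invertible (its determinant is the product of the diagonal), so Σ = L Lᵀ is automatically symmetric positive definite.

```latex
\begin{aligned}
\Sigma &= L L^\top, \qquad \det L = \prod_{i=1}^{3} L_{ii} > 0
  \;\Longrightarrow\; L^\top v \neq 0 \ \text{for all } v \neq 0,\\
v^\top \Sigma v &= (L^\top v)^\top (L^\top v) = \lVert L^\top v \rVert_2^2 > 0 .
\end{aligned}
```

So any 3 unconstrained below-diagonal values plus 3 strictly positive diagonal values give a valid covariance, and every positive-definite Σ has exactly one such L (its Cholesky factor), so this parametrization loses no generality.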
@phan_phan can you explain your solution in more detail? Do you regress 6 parameters and pass them straight to `scale_tril`?
How exactly do you generate `scale_tril`?
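Yes, I eventually did a regression over the distribution parameters.
So my last NN layer outputs a tensor of size [batch, 9]: 3 values for the mean, 3 for the below-diagonal entries of `scale_tril` and 3 for its diagonal.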
However, a few things to note:
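- The diagonal of `scale_tril` must be strictly positive, so I used 1 + ELU (ELU is bounded below by -1, so 1 + ELU is always positive):

```python
import torch.nn as nn

# outputs: [batch, 9] tensor produced by the last layer
loc, tril, diag = outputs.split(3, dim=-1)
diag = 1 + nn.functional.elu(diag)  # strictly positive diagonal
# Each of loc, tril, diag is a tensor of size [batch, 3]
```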
- I generated `scale_tril` the dirty way; maybe there's a better one:

```python
import torch
from torch.distributions import MultivariateNormal

z = torch.zeros_like(diag[:, 0])  # zeros for the upper triangle (same device/dtype as diag)
scale_tril = torch.stack([
    diag[:, 0], z,          z,
    tril[:, 0], diag[:, 1], z,
    tril[:, 1], tril[:, 2], diag[:, 2],
], dim=-1).view(-1, 3, 3)
# scale_tril is a lower-triangular tensor of size [batch, 3, 3]
dist = MultivariateNormal(loc=loc, scale_tril=scale_tril)
```
- To train, instead of minimizing a criterion between a distribution sample and the label (as I originally planned), it is better to minimize the negative log-likelihood of the labels under the distribution, using the `.log_prob()` method:

```python
loss = -dist.log_prob(labels).mean()
```
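Putting the pieces together, here is a minimal end-to-end sketch. It is not the code from this thread and rests on assumptions of its own: the fixed input is a vector of ones with 4 features, the labels come from a made-up 3D Gaussian (`true_loc`, `true_L`), and `GaussianHead` is a hypothetical small MLP. It keeps the 1 + ELU diagonal but fills the lower triangle with `torch.tril_indices` instead of stacking zeros, as one possible tidier alternative, and trains on the negative log-likelihood.

```python
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

torch.manual_seed(0)

class GaussianHead(nn.Module):
    """Hypothetical model: maps an input to a MultivariateNormal via (loc, scale_tril)."""

    def __init__(self, in_features: int, dim: int = 3, hidden: int = 32):
        super().__init__()
        self.dim = dim
        n_offdiag = dim * (dim - 1) // 2
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, dim + n_offdiag + dim),  # loc, strict lower triangle, diagonal
        )
        # Row/column indices of the strictly lower triangle (offset=-1 excludes the diagonal)
        self.register_buffer("tril_idx", torch.tril_indices(dim, dim, offset=-1))

    def forward(self, x: torch.Tensor) -> MultivariateNormal:
        out = self.net(x)
        loc, tril, diag = out.split([self.dim, self.tril_idx.size(1), self.dim], dim=-1)
        diag = 1 + nn.functional.elu(diag)      # strictly positive diagonal
        scale_tril = torch.diag_embed(diag)     # [batch, dim, dim], zeros off the diagonal
        scale_tril[..., self.tril_idx[0], self.tril_idx[1]] = tril  # fill below the diagonal
        return MultivariateNormal(loc=loc, scale_tril=scale_tril)

# Made-up target distribution for the labels (assumption for illustration)
true_loc = torch.tensor([1.0, -2.0, 0.5])
true_L = torch.tensor([[1.0, 0.0, 0.0],
                       [0.6, 0.8, 0.0],
                       [-0.4, 0.3, 0.5]])
target = MultivariateNormal(true_loc, scale_tril=true_L)

model = GaussianHead(in_features=4)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.ones(256, 4)  # the fixed input, repeated for a batch of 256 labels

losses = []
for step in range(500):
    labels = target.sample((256,))       # [256, 3]
    dist = model(x)                      # batch of 256 identical Gaussians
    loss = -dist.log_prob(labels).mean() # negative log-likelihood
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

With a fixed input every row of the batch predicts the same Gaussian, so the model is effectively fitting one mean and one Cholesky factor to the labels; with real, varying inputs the same head gives an input-dependent distribution.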
Thank you so much, I will try it!