Optimized MultivariateNormal with diagonal covariance matrix

The current MultivariateNormal implementation adds significant overhead to my code at large batch sizes. After looking into the source, the main cause seems to be a helper function used in the log_prob method, which runs an explicit for loop over what should be a batched operation:

The other issue is that my current model only needs a diagonal covariance, but the MultivariateNormal class is very general and runs computations, like torch.trtrs, that are unnecessary for a diagonal covariance. Would it make sense to have a MultivariateNormal implementation optimized for strictly diagonal covariances? I've noticed there's a new LowRankMultivariateNormal in the master branch that hasn't made it into the stable release yet. That implementation might be more suitable, since its constructor takes a cov_diag explicitly, but it also takes a cov_factor, which might trigger unnecessary computations for a strictly diagonal covariance as well:
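For reference, with a diagonal covariance $\Sigma = \operatorname{diag}(\sigma_1^2, \dots, \sigma_d^2)$ the log-density separates into a sum of independent univariate terms, which is why no Cholesky factor or triangular solve is needed in this case:

```latex
\log p(x) = \sum_{i=1}^{d} \log \mathcal{N}\!\left(x_i \mid \mu_i, \sigma_i^2\right)
          = -\frac{1}{2} \sum_{i=1}^{d} \left( \frac{x_i - \mu_i}{\sigma_i} \right)^{2}
            - \sum_{i=1}^{d} \log \sigma_i
            - \frac{d}{2} \log (2\pi)
```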

What’s the recommended approach to create an efficient multivariate normal distribution with a strictly diagonal covariance matrix?


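You can just use torch.distributions.Normal in that case; summing its log_prob over the last dimension gives the same result as the diagonal MultivariateNormal:

```python
import torch

n = 2  # batch size
d = 5  # event dimension
diagonal = torch.rand(d) + 1.  # standard deviations (strictly positive)
mu = torch.rand(n, d)
p1 = torch.distributions.Normal(mu, diagonal.reshape(1, d))
p2 = torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag(diagonal).reshape(1, d, d))
x = torch.rand((n, d))
# Normal returns per-coordinate log-probs, so sum over the event dim; the difference should be ~0
print(p1.log_prob(x).sum(dim=1) - p2.log_prob(x))
```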
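To see why Normal is enough, here is a minimal sketch that computes the diagonal-Gaussian log-density by hand and compares it with the summed Normal.log_prob. It assumes `scale` holds standard deviations (not variances); the helper `diag_gaussian_log_prob` is hypothetical and only illustrates that the computation is purely element-wise, with no matrix factorization or loop over the batch.

```python
import math

import torch
from torch.distributions import Normal


def diag_gaussian_log_prob(x, loc, scale):
    """Log-density of a Gaussian with covariance diag(scale**2).

    Hypothetical helper for illustration: it uses only element-wise ops
    and a reduction over the last (event) dimension, so it stays fully
    vectorized over any leading batch dimensions.
    """
    z = (x - loc) / scale
    d = x.shape[-1]
    return (-0.5 * z.pow(2).sum(-1)
            - scale.log().sum(-1)
            - 0.5 * d * math.log(2 * math.pi))


torch.manual_seed(0)
n, d = 4, 5
loc = torch.rand(n, d)
scale = torch.rand(d) + 1.0  # strictly positive standard deviations
x = torch.rand(n, d)

manual = diag_gaussian_log_prob(x, loc, scale)  # shape (n,)
# Normal treats every dim as a batch dim, so sum over the event dim ourselves.
builtin = Normal(loc, scale).log_prob(x).sum(-1)
print(torch.allclose(manual, builtin, atol=1e-5))
```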

The above answer is the way to go, but it can lead to confusing issues if you're new to torch.distributions.

To get the correct Kullback-Leibler divergence (and the correct shape from .log_prob), we need to wrap the Normal in the Independent class, which reinterprets some number of batch dimensions as event dimensions. A plain Normal treats all dimensions as batch dimensions, which is generally not the behaviour you want when you're effectively defining a multivariate normal with diagonal covariance. See below and https://pytorch.org/docs/stable/distributions.html#independent

```python
>>> import torch
>>> from torch.distributions import Independent, MultivariateNormal, Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]
```
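A short sketch of the log_prob and KL points above, using made-up parameters for two batches of diagonal Gaussians. A plain Normal returns one value per coordinate, while wrapping it in Independent(..., 1) returns one value per batch element, and its KL divergence equals the per-coordinate KL summed over the event dimension.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)
batch, d = 3, 4
loc_p, loc_q = torch.randn(batch, d), torch.randn(batch, d)
scale_p = torch.rand(batch, d) + 0.5  # standard deviations, must be > 0
scale_q = torch.rand(batch, d) + 0.5

normal_p, normal_q = Normal(loc_p, scale_p), Normal(loc_q, scale_q)
# Reinterpret the last batch dimension as the event dimension.
p, q = Independent(normal_p, 1), Independent(normal_q, 1)

x = torch.randn(batch, d)
print(normal_p.log_prob(x).shape)  # torch.Size([3, 4]): one value per coordinate
print(p.log_prob(x).shape)         # torch.Size([3]): one value per batch element

kl_plain = kl_divergence(normal_p, normal_q)  # shape (3, 4), per coordinate
kl_indep = kl_divergence(p, q)                # shape (3,), summed over the event dim
print(torch.allclose(kl_indep, kl_plain.sum(-1)))
```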

Please use JakobHavtorn's approach rather than mine; it's the most appropriate way to go as of now.

It is clear that an axis-aligned Gaussian can be created using torch.distributions.Normal, and that if we need to compute the KL divergence we can wrap it in Independent to reinterpret the batch dimension as an event dimension.

My question is: can we achieve the same result for a given mean and sigma using torch.distributions.MultivariateNormal?

Example case with batch_size = 5 and 2 event dimensions:

```python
>>> import torch
>>> from torch.distributions import Independent, MultivariateNormal, Normal
>>> mu = torch.ones([5, 2])
>>> log_sig = torch.ones([5, 2])
>>> indep = Independent(Normal(loc=mu, scale=torch.exp(log_sig)), 1)
>>> indep.event_shape, indep.batch_shape
(torch.Size([2]), torch.Size([5]))

>>> mvn = MultivariateNormal(mu, scale_tril=torch.diag(torch.exp(log_sig)))  # THIS FAILS obviously
```

How can we create a MultivariateNormal instance from the given mu and sigma so that it infers batch_shape=torch.Size([5])?
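One possible approach, sketched under the assumption that sigma = exp(log_sig) as in the Independent example: build a batch of diagonal scale matrices instead of a single one. torch.diag on a 2-D input extracts a diagonal rather than building matrices, whereas torch.diag_embed places each row of sigma on the diagonal of its own 2 x 2 matrix, giving a (5, 2, 2) scale_tril from which MultivariateNormal infers the batch shape. This reproduces the shapes and log-probs, but it still works with full d x d matrices, so it does not give the efficiency of the Independent(Normal) route.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

mu = torch.ones(5, 2)
log_sig = torch.ones(5, 2)
sigma = torch.exp(log_sig)  # standard deviations, shape (5, 2)

indep = Independent(Normal(loc=mu, scale=sigma), 1)

# diag_embed: (5, 2) -> (5, 2, 2), one diagonal matrix per batch element.
mvn = MultivariateNormal(mu, scale_tril=torch.diag_embed(sigma))
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([5]) torch.Size([2])

x = torch.randn(5, 2)
# Both parameterizations describe the same distributions.
print(torch.allclose(indep.log_prob(x), mvn.log_prob(x), atol=1e-5))
```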


EDIT: I simplified the question here: