The current MultivariateNormal implementation adds significant overhead to my code when I use large batch sizes. After looking into the source, it seems the main cause is a helper function used in the log_prob method, which runs an explicit for loop over the batch instead of a single batched operation.
The other issue is that I only need a diagonal covariance for my current model, but the MultivariateNormal object is very general and runs computations that are unnecessary for a diagonal covariance, such as torch.trtrs. Would it make sense to have a MultivariateNormal implementation optimized for strictly diagonal covariances? I've noticed there's a new LowRankMultivariateNormal in the master branch that hasn't made it into the stable release yet. That implementation might be more suitable, since its constructor takes cov_diag explicitly, but it also requires a cov_factor, which might add unnecessary computation for a strictly diagonal covariance as well.
What’s the recommended approach to create an efficient multivariate normal distribution with a strictly diagonal covariance matrix?
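One commonly used approach, sketched here assuming a reasonably recent PyTorch, is to build the diagonal Gaussian from Normal (one standard deviation per dimension) and wrap it in Independent. The log-density then reduces to a sum of univariate terms, so no matrix factorization, triangular solve, or loop over the batch is involved. The tensor sizes below are arbitrary illustrative values, and the manual formula is included only to show what log_prob computes.

```python
import math

import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
batch_size, dim = 4096, 16  # illustrative sizes only
loc = torch.randn(batch_size, dim)
scale = torch.rand(batch_size, dim) + 0.1  # std devs = sqrt of the covariance diagonal

# Independent(..., 1) treats the last dimension as the event dimension,
# so log_prob sums over it and returns one value per batch element.
diag_gauss = Independent(Normal(loc, scale), 1)

x = diag_gauss.sample()
log_p = diag_gauss.log_prob(x)

# Closed-form diagonal-Gaussian log-density, elementwise and then summed:
# log N(x) = sum_i [ -(x_i - mu_i)^2 / (2 sigma_i^2) - log sigma_i - 0.5 log(2 pi) ]
manual = (
    -0.5 * ((x - loc) / scale) ** 2 - scale.log() - 0.5 * math.log(2 * math.pi)
).sum(-1)

print(log_p.shape)  # torch.Size([4096])
print(torch.allclose(log_p, manual, atol=1e-3))
```

The cost is linear in batch_size * dim, which is why this route tends to be much cheaper than a general MultivariateNormal for large batches.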
Using Normal, as the answer above suggests, is the way to go, but it can lead to confusing issues if you're new to torch.distributions.
To get the correct Kullback-Leibler divergence (and the correct shape from .log_prob), we need to wrap the Normal in the Independent class, which reinterprets some number of batch dimensions as event dimensions. A plain Normal does not do this: it treats every dimension as a batch dimension, which is generally not the behaviour you want when you're effectively defining a multivariate normal with diagonal covariance. See below and https://pytorch.org/docs/stable/distributions.html#independent
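A minimal sketch of the shape difference described above, assuming a PyTorch version where kl_divergence is registered for pairs of Independent distributions (the sizes 8 and 3 are arbitrary). With a plain Normal, kl_divergence returns one value per coordinate; after wrapping with Independent(..., 1) it returns one value per 3-dimensional Gaussian, equal to the per-coordinate KLs summed over the event dimension.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)
loc_p, scale_p = torch.zeros(8, 3), torch.ones(8, 3)
loc_q, scale_q = torch.randn(8, 3), torch.rand(8, 3) + 0.5

# Plain Normal: every dimension is a batch dimension, event_shape is ().
p_plain, q_plain = Normal(loc_p, scale_p), Normal(loc_q, scale_q)
print(p_plain.batch_shape, p_plain.event_shape)  # torch.Size([8, 3]) torch.Size([])
print(kl_divergence(p_plain, q_plain).shape)     # torch.Size([8, 3]): one KL per coordinate

# Independent(..., 1): the last dimension is reinterpreted as the event dimension.
p_ind, q_ind = Independent(p_plain, 1), Independent(q_plain, 1)
print(p_ind.batch_shape, p_ind.event_shape)      # torch.Size([8]) torch.Size([3])
print(kl_divergence(p_ind, q_ind).shape)         # torch.Size([8]): one KL per 3-D Gaussian
```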
It is clear that an axis-aligned Gaussian can be created using torch.distributions.Normal, and if we need to compute the KL divergence we can wrap it in Independent with reinterpreted batch dimensions.
My question is: can we achieve the same result for a given mean and sigma using torch.distributions.MultivariateNormal?
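A sketch suggesting the answer is yes, assuming sigma holds per-dimension standard deviations: the Cholesky factor of diag(sigma**2) is simply diag(sigma), so it can be passed to MultivariateNormal as scale_tril via torch.diag_embed (passing covariance_matrix=torch.diag_embed(sigma**2) would also work but triggers a Cholesky decomposition). The example compares log_prob and kl_divergence against the Normal + Independent version; the sizes and random parameters are illustrative only.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 4)
sigma = torch.rand(5, 4) + 0.1  # per-dimension standard deviations

# Diagonal Gaussian via Normal + Independent.
diag = Independent(Normal(mu, sigma), 1)

# Same distribution via MultivariateNormal; diag_embed turns (5, 4) into (5, 4, 4)
# diagonal matrices, which are valid lower-triangular Cholesky factors.
mvn = MultivariateNormal(mu, scale_tril=torch.diag_embed(sigma))

x = diag.sample()
same_log_prob = torch.allclose(diag.log_prob(x), mvn.log_prob(x), atol=1e-4)

# KL divergence against a second diagonal Gaussian, computed both ways.
mu2, sigma2 = torch.randn(5, 4), torch.rand(5, 4) + 0.5
diag2 = Independent(Normal(mu2, sigma2), 1)
mvn2 = MultivariateNormal(mu2, scale_tril=torch.diag_embed(sigma2))
same_kl = torch.allclose(kl_divergence(diag, diag2), kl_divergence(mvn, mvn2), rtol=1e-4, atol=1e-4)

print(same_log_prob, same_kl)
```

Both comparisons should agree up to floating-point error. The MultivariateNormal route still goes through the general dense-matrix code path (triangular solves over d x d matrices), so for strictly diagonal covariances the Normal + Independent version is likely to remain the cheaper option.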