How to instantiate axis-aligned multivariate Gaussian?

It is straightforward to wrap a Normal in an Independent. For example, given

loc of shape [batch_size, event_shape]
scale of shape [event_shape] (broadcast over the batch)


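import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(5, 3, 2)
scale = torch.ones(2)  # broadcast against loc's last dimension
normal = Normal(loc, scale)
normal.batch_shape, normal.event_shape
# -> (torch.Size([5, 3, 2]), torch.Size([]))

# Reinterpret the rightmost batch dimension as the event dimension
ind = Independent(normal, 1)
ind.batch_shape, ind.event_shape
# -> (torch.Size([5, 3]), torch.Size([2]))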
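What Independent changes is how log_prob is reduced: the plain Normal scores each coordinate separately, while the Independent sums those scores over the reinterpreted event dimension. A small sketch of that, reusing the shapes above (the test point x is arbitrary):

```python
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
loc = torch.zeros(5, 3, 2)
scale = torch.ones(2)
normal = Normal(loc, scale)
ind = Independent(normal, 1)

x = torch.randn(5, 3, 2)
# Normal: one log-density per coordinate, shape (5, 3, 2)
lp_normal = normal.log_prob(x)
# Independent: summed over the event dim, shape (5, 3)
lp_ind = ind.log_prob(x)

print(lp_normal.shape, lp_ind.shape)
print(torch.allclose(lp_ind, lp_normal.sum(-1)))
```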

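This Independent distribution defines the same distribution as a MultivariateNormal with a diagonal covariance. Note that covariance_matrix holds variances, so its diagonal must be scale**2 (torch.diag(scale) only happens to work here because scale is all ones):
mvn = MultivariateNormal(loc, covariance_matrix=torch.diag(scale ** 2))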
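A minimal numerical check of that equivalence, as a sketch: the scale values [0.5, 2.0] are hypothetical, picked so that standard deviations and variances differ. Because covariance_matrix takes variances, the diagonal is scale ** 2, and both objects should then assign the same log-probability to any point.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
loc = torch.zeros(5, 3, 2)
# Hypothetical non-unit std devs so sigma and sigma**2 differ
scale = torch.tensor([0.5, 2.0])

ind = Independent(Normal(loc, scale), 1)
# The (2, 2) covariance broadcasts against loc's batch shape (5, 3)
mvn_cov = MultivariateNormal(loc, covariance_matrix=torch.diag(scale ** 2))

x = torch.randn(5, 3, 2)
print(mvn_cov.batch_shape, mvn_cov.event_shape)
print(torch.allclose(ind.log_prob(x), mvn_cov.log_prob(x), atol=1e-5))
```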

But what if I have the following?


import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(5, 2)
log_sigma = torch.ones(5, 2)
ind = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
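For reference, the shapes involved: this Independent holds one 2-D Gaussian per row of mu, so any MultivariateNormal meant to match it needs batch_shape (5,) and event_shape (2,), which means a covariance tensor of shape (5, 2, 2). A short sketch that prints them:

```python
import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(5, 2)
log_sigma = torch.ones(5, 2)
ind = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)

# One 2-D Gaussian per row: batch_shape (5,), event_shape (2,)
print(ind.batch_shape, ind.event_shape)
# A single draw has shape batch_shape + event_shape = (5, 2)
sample = ind.sample()
# log_prob gives one value per batch element, shape (5,)
lp = ind.log_prob(sample)
print(sample.shape, lp.shape)
```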

This works fine; however, it is not clear how to express the same thing with MultivariateNormal:

mvn = MultivariateNormal(mu, ???)

Is it possible to use MultivariateNormal and get the same distribution? How do I provide the covariance_matrix argument given log_sigma of shape [batch_size, event_shape]?


EDIT:

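from torch.distributions import MultivariateNormal

mu = torch.zeros(5, 2)
log_sigma = torch.ones(5, 2)
# covariance_matrix expects variances, so square the std devs on the diagonal
cov = torch.stack([torch.diag(sigma ** 2) for sigma in torch.exp(log_sigma)])

mvn = MultivariateNormal(mu, cov)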

Would this be equivalent to Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)?
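Whether a stacked-diag construction gives the right covariance depends on what goes on the diagonal: covariance_matrix expects variances, sigma**2 = exp(2 * log_sigma), not standard deviations. A sketch that applies the per-row loop to the variances and compares it with the vectorized torch.diag_embed (random log_sigma is used only for illustration):

```python
import torch

torch.manual_seed(0)
log_sigma = torch.randn(5, 2)
var = torch.exp(2 * log_sigma)  # sigma**2

# Per-row loop: one (2, 2) diagonal matrix per batch element, then stack
cov_loop = torch.stack([torch.diag(v) for v in var])
# Vectorized: diag_embed puts the last dim on the diagonal of a new (2, 2) block
cov_vec = torch.diag_embed(var)

print(cov_loop.shape, cov_vec.shape)
print(torch.equal(cov_loop, cov_vec))
```

diag_embed avoids the Python loop and works for any number of leading batch dimensions.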


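I think torch.diag_embed is useful for constructing the tensor of diagonal covariance matrices. Note that the diagonal must hold variances, sigma**2 = exp(2 * log_sigma), not log_sigma itself. Hence:

import torch
from torch.distributions import MultivariateNormal

mu = torch.zeros(5, 2)
log_sigma = torch.ones(5, 2)
# Variances on the diagonal: exp(log_sigma)**2 = exp(2 * log_sigma); shape (5, 2, 2)
cov = torch.diag_embed(torch.exp(2 * log_sigma))
mvn = MultivariateNormal(mu, covariance_matrix=cov)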
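To confirm the equivalence numerically, here is a sketch comparing shapes, log_prob and entropy against the Independent version (random per-row log_sigma values are used instead of all ones so that each row has different scales). It also shows the scale_tril argument, which takes the Cholesky factor; for a diagonal covariance that is simply diag(sigma), so no squaring is needed.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
mu = torch.zeros(5, 2)
log_sigma = 0.5 * torch.randn(5, 2)  # illustrative per-row log std devs
sigma = torch.exp(log_sigma)

ind = Independent(Normal(loc=mu, scale=sigma), 1)
# Covariance route: variances on the diagonal
mvn = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(sigma ** 2))
# Cholesky route: std devs on the diagonal of the lower-triangular factor
mvn_tril = MultivariateNormal(mu, scale_tril=torch.diag_embed(sigma))

x = torch.randn(5, 2)
print(mvn.batch_shape == ind.batch_shape, mvn.event_shape == ind.event_shape)
print(torch.allclose(mvn.log_prob(x), ind.log_prob(x), atol=1e-4))
print(torch.allclose(mvn_tril.log_prob(x), ind.log_prob(x), atol=1e-4))
print(torch.allclose(mvn.entropy(), ind.entropy(), atol=1e-4))
```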