Efficiently computing log probabilities of a batch of sample points under N different multivariate Gaussian distributions

Suppose I have N 2D multivariate Gaussian distributions. I can define them using torch.distributions.multivariate_normal.MultivariateNormal as

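import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal

inp_dim = 2
n_gaussians = 5  # N; example value
mu = nn.Parameter(torch.rand(n_gaussians, inp_dim))
# identity lower-triangular factor for each Gaussian, shape (n_gaussians, inp_dim, inp_dim)
cov = nn.Parameter(torch.eye(inp_dim).unsqueeze(0).repeat(n_gaussians, 1, 1))

mvn = MultivariateNormal(mu, scale_tril=torch.tril(cov))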

Suppose I have B sample points and I want to evaluate the log probability of each of these B points under each of the N 2D multivariate Gaussian distributions, i.e. I will have N*B log probabilities. Simply calling the following does not work, because it treats n_gaussians as the batch shape, and in my case N is different from B.

log_prob = mvn.log_prob(x)

How can I efficiently evaluate the log probabilities of B sample points under N multivariate Gaussians?

Hi Elkop!

Add a singleton dimension to x that will broadcast over the n_gaussians
dimension of mvn. That way the n_gaussians and “batch” dimensions will
be kept separate and you will get your full set of N*B log probabilities.
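
To make the shapes concrete, here is a minimal sketch of the broadcasting described above. The sizes N, B and D and the variable names are just illustrative assumptions, and plain tensors are used instead of nn.Parameter to keep it short. It shows that mvn has batch shape (N,) and event shape (D,), that an x of shape (B, D) is rejected when B differs from N, and that an x of shape (B, 1, D) yields a (B, N) result.

```python
import torch
from torch.distributions import MultivariateNormal

# Hypothetical sizes: N Gaussians, B sample points, D dimensions
N, B, D = 5, 3, 2

mu = torch.rand(N, D)
scale_tril = torch.eye(D).unsqueeze(0).repeat(N, 1, 1)
mvn = MultivariateNormal(mu, scale_tril=scale_tril)

print(mvn.batch_shape)  # one batch entry per Gaussian
print(mvn.event_shape)  # one event is a single D-dimensional point

x = torch.randn(B, D)

# (B, D) cannot be broadcast against batch_shape + event_shape = (N, D)
# when B != N, so this call is expected to raise.
try:
    mvn.log_prob(x)
    naive_ok = True
except (ValueError, RuntimeError):
    naive_ok = False
print(naive_ok)

# (B, 1, D) broadcasts against (N, D), giving one value per (point, Gaussian) pair
log_prob = mvn.log_prob(x.unsqueeze(1))
print(log_prob.shape)
```

Entry [b, n] of the result is the log probability of sample point b under Gaussian n.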

Consider:

>>> import torch
>>> print (torch.__version__)
2.1.2
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> inp_dim = 2
>>> n_gaussians = 5
>>> n_batch = 3
>>>
>>> mu = torch.nn.Parameter (torch.rand (n_gaussians, inp_dim))
>>> cov = torch.nn.Parameter (torch.eye (inp_dim).unsqueeze (0).repeat (n_gaussians, 1, 1))
>>>
>>> mvn = torch.distributions.multivariate_normal.MultivariateNormal (mu, scale_tril = torch.tril (cov))
>>>
>>> x = torch.randn (n_batch, 1, inp_dim)   # singleton dimension will broadcast over mvn's n_gaussians
>>> mvn.log_prob (x)
tensor([[-1.9006, -2.4927, -2.0306, -1.8988, -1.9011],
        [-2.9162, -3.7090, -2.4418, -2.4743, -2.8308],
        [-1.9674, -2.0202, -1.8672, -1.8966, -1.9255]], grad_fn=<SubBackward0>)
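
If you want to convince yourself that the broadcast version computes the right thing, one way (a sketch only, with made-up sizes and random non-identity scale_tril factors that are my own assumption, not part of your setup) is to compare it against an explicit loop that builds one distribution per Gaussian:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N, B, D = 5, 3, 2  # hypothetical sizes

mu = torch.rand(N, D)
# Random lower-triangular factors with a strictly positive diagonal
strict_lower = torch.tril(torch.randn(N, D, D), diagonal=-1)
positive_diag = torch.diag_embed(torch.rand(N, D) + 0.5)
scale_tril = strict_lower + positive_diag

mvn = MultivariateNormal(mu, scale_tril=scale_tril)
x = torch.randn(B, D)

# Broadcast version: (B, 1, D) against batch shape (N,) gives (B, N)
batched = mvn.log_prob(x.unsqueeze(1))

# Reference version: one distribution per Gaussian, each evaluated on all B points
looped = torch.stack(
    [MultivariateNormal(mu[k], scale_tril=scale_tril[k]).log_prob(x) for k in range(N)],
    dim=1,
)

print(batched.shape, looped.shape)
print(torch.allclose(batched, looped, atol=1e-5))
```

The loop is only there as a reference; the broadcast call does the same work in a single batched operation.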

Best.

K. Frank
