Scoring batch samples using MultivariateNormal

Hi,
I have a set of k MultivariateNormal distributions in d dimensions:

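import torch

mu = torch.empty(k, d)  # one mean vector per component: shape (k, d)
sigma = torch.empty(k, d, d)  # one lower-triangular scale factor per component: shape (k, d, d)
...
D = torch.distributions.MultivariateNormal(loc=mu, scale_tril=sigma)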
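The `...` above stands for however the parameters get filled in. As a minimal sketch, assuming random values are fine for illustration, here is one way to build parameters that `MultivariateNormal` accepts. Because `sigma` is passed as `scale_tril`, it has to be a lower-triangular Cholesky factor of the covariance with a strictly positive diagonal, not the covariance matrix itself. The seed, the `0.1` scaling and the diagonal range are arbitrary choices of this sketch.

```python
import torch

torch.manual_seed(0)
k, d = 8, 16

mu = torch.randn(k, d)  # one mean per component, shape (k, d)

# scale_tril must be lower-triangular with a strictly positive diagonal:
# small random strictly-lower entries plus a diagonal drawn from [0.5, 1.5).
sigma = 0.1 * torch.tril(torch.randn(k, d, d), diagonal=-1) + torch.diag_embed(torch.rand(k, d) + 0.5)

D = torch.distributions.MultivariateNormal(loc=mu, scale_tril=sigma)
print(D.batch_shape)  # torch.Size([8])  -> the k components
print(D.event_shape)  # torch.Size([16]) -> each sample is d-dimensional
```

With `batch_shape` (k,) and `event_shape` (d,), any value passed to `log_prob` has to be broadcast-compatible with `(k, d)`.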

I have a batch of N d-dimensional samples (a tensor of shape (N, d)), and I want to get the log_prob of each sample under each of the distributions (so k values per sample).

When I do

D.log_prob(batch)

With N = 64, k = 8 and d = 16, I get the following error:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([64, 16]) vs torch.Size([8, 16]).

I’ve tried expanding the distribution to match the batch shape, but I get an equivalent error:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([64, 16]) vs torch.Size([64, 8, 16]).
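Both errors come from the same check: `log_prob` requires the value's shape to be broadcast-compatible with `batch_shape + event_shape`, comparing dimensions from the right. A small sketch of that rule using `torch.broadcast_shapes` (the helper `broadcast_or_none` is hypothetical, written only for this illustration) shows why neither `(8, 16)` nor the expanded `(64, 8, 16)` accepts a `(64, 16)` value:

```python
import torch

def broadcast_or_none(*shapes):
    """Return the broadcast shape, or None if the shapes cannot be broadcast together."""
    try:
        return torch.broadcast_shapes(*shapes)
    except RuntimeError:
        return None

value = torch.Size([64, 16])                    # N = 64 samples, d = 16
original = torch.Size([8]) + torch.Size([16])   # batch_shape + event_shape = (8, 16)
expanded = torch.Size([64, 8]) + torch.Size([16])  # after expanding batch_shape to (64, 8)

# Dimensions are compared right to left: 16 vs 16 matches, then 64 vs 8 does not.
print(broadcast_or_none(value, original))  # None
print(broadcast_or_none(value, expanded))  # None

# A singleton dimension in the k position lines up with the 8 and broadcasts.
print(broadcast_or_none(torch.Size([64, 1, 16]), original))  # torch.Size([64, 8, 16])
```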

Is there a way to do it that doesn’t involve iterating through the batch and calling log_prob N times?
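One loop-free approach, offered as a sketch rather than a definitive answer: give the value a singleton dimension where the k components live. A value of shape (N, 1, d) broadcasts against `batch_shape + event_shape` = (k, d) to (N, k, d), so `log_prob` returns an (N, k) tensor, one row per sample and one column per component. The random parameters are illustrative assumptions, and the comparison with an explicit per-sample loop is only there to check that both give the same numbers.

```python
import torch

torch.manual_seed(0)
N, k, d = 64, 8, 16

# Illustrative parameters: random means and a valid lower-triangular scale_tril.
mu = torch.randn(k, d)
scale_tril = 0.1 * torch.tril(torch.randn(k, d, d), diagonal=-1) + torch.diag_embed(torch.rand(k, d) + 0.5)
D = torch.distributions.MultivariateNormal(loc=mu, scale_tril=scale_tril)

batch = torch.randn(N, d)

# (N, d) -> (N, 1, d): the singleton dim broadcasts against batch_shape (k,).
log_probs = D.log_prob(batch.unsqueeze(1))
print(log_probs.shape)  # torch.Size([64, 8])

# Reference: the per-sample loop the question wants to avoid (each call returns shape (k,)).
looped = torch.stack([D.log_prob(x) for x in batch])
print(torch.allclose(log_probs, looped, rtol=1e-4, atol=1e-4))
```

`batch[:, None, :]` is equivalent to `batch.unsqueeze(1)`; no `expand` of the distribution is needed, since broadcasting inside `log_prob` does the pairing of every sample with every component.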
