Hi,
I have a set of k MultivariateNormal distributions in d dimensions.
mu = torch.FloatTensor(k, d)
sigma = torch.FloatTensor(k, d, d)
...
D = torch.distributions.MultivariateNormal(loc=mu, scale_tril=sigma)
I have a batch of N d-dimensional samples, and I want to get the log_prob of each sample under each of the distributions (so k values per sample).
When I do
D.log_prob(batch)
with N = 64, k = 8 and d = 16, I get the following error:
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([64, 16]) vs torch.Size([8, 16]).
I’ve tried expanding the distribution to match the batch shape, but I get an equivalent error:
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([64, 16]) vs torch.Size([64, 8, 16]).
Is there a way to do this that doesn’t involve iterating through the batch and calling log_prob N times?
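One approach that may work is to rely on broadcasting instead of expanding the distribution: give the batch a singleton dimension where the k distributions sit, so its shape becomes (N, 1, d) and broadcasts against batch_shape + event_shape = (k, d). Below is a minimal sketch under some assumptions: the parameters are random placeholders (not my real mu and sigma), sigma is built to be a valid lower-triangular factor with a positive diagonal, and float64 is used only so the comparison against the explicit loop is tight.

```python
import torch

N, k, d = 64, 8, 16
torch.manual_seed(0)

# Placeholder parameters (assumption): random means and a valid scale_tril,
# i.e. lower triangular with a strictly positive diagonal.
mu = torch.randn(k, d, dtype=torch.float64)
off_diag = 0.1 * torch.tril(torch.randn(k, d, d, dtype=torch.float64), diagonal=-1)
diag = torch.diag_embed(torch.rand(k, d, dtype=torch.float64) + 0.5)
sigma = off_diag + diag

D = torch.distributions.MultivariateNormal(loc=mu, scale_tril=sigma)
batch = torch.randn(N, d, dtype=torch.float64)

# (N, d) -> (N, 1, d): the singleton axis broadcasts against the k distributions.
log_probs = D.log_prob(batch.unsqueeze(1))  # shape (N, k)

# Reference: the loop I am trying to avoid, one log_prob call per sample.
reference = torch.stack([D.log_prob(x) for x in batch])  # shape (N, k)
```

If this is right, log_probs[i, j] should be the log-density of sample i under distribution j, and it should match the looped reference.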