Suppose I have B sample points and I want to evaluate their log probabilities under each of N 2D multivariate Gaussian distributions, i.e. I want N*B log probabilities in total. Simply calling the following does not work: log_prob treats n_gaussians (N) as the batch shape of mvn and tries to broadcast x against it, but in my case N is different from B.
log_prob = mvn.log_prob(x)
How can I efficiently evaluate the log probabilities of the B sample points under all N multivariate Gaussians?
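Add a singleton dimension to x so that it broadcasts over the n_gaussians
dimension of mvn. If x has shape [B, 2], then x.unsqueeze(1) has shape
[B, 1, 2]. That way the n_gaussians and “batch” (B) dimensions are kept
separate, and mvn.log_prob(x.unsqueeze(1)) returns the full set of N*B
log probabilities as a [B, N] tensor.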
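Here is a minimal sketch of this, assuming PyTorch's torch.distributions.MultivariateNormal. The sizes N = 3 and B = 5 and the random means and covariances are made up purely for illustration. It also checks one entry against a standalone, unbatched Gaussian.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)

# Illustrative sizes (assumed): N Gaussians, B sample points, 2D events.
N, B, D = 3, 5, 2

# A batch of N 2D Gaussians: batch_shape = (N,), event_shape = (2,).
loc = torch.randn(N, D)
# Random positive-definite covariances: A @ A^T + 0.1 * I.
A = torch.randn(N, D, D)
cov = A @ A.transpose(-1, -2) + 0.1 * torch.eye(D)
mvn = MultivariateNormal(loc, covariance_matrix=cov)

# B sample points, shape (B, 2).
x = torch.randn(B, D)

# Insert a singleton dim: (B, 1, 2). The leading shape (B, 1) broadcasts
# against the batch shape (N,), so the result has shape (B, N).
log_prob = mvn.log_prob(x.unsqueeze(1))
print(log_prob.shape)  # torch.Size([5, 3])

# Entry [b, n] is log p(x[b]) under Gaussian n; compare one entry with
# an unbatched distribution built from the n-th mean and covariance.
n, b = 1, 2
single = MultivariateNormal(loc[n], covariance_matrix=cov[n])
assert torch.allclose(log_prob[b, n], single.log_prob(x[b]), atol=1e-4)
```

The result is laid out as [B, N]; use log_prob.T if you prefer one row per Gaussian ([N, B]).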