Creating a Tensor from MultivariateNormal

Hi guys, I am in what I believe to be a tricky situation. I am using a VAE-based net to generate the mu and sigma parameters (both 2-D, that is, (x, y)), and from these values I create a distribution:

m = MultivariateNormal(mu, torch.diag(sigma))
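When mu and sigma arrive as a batch of shape (B, 2), torch.diag does not build one covariance matrix per row: on a 2-D input it returns the main diagonal instead. A minimal sketch, assuming sigma holds per-axis variances (as torch.diag(sigma) implies) and using random stand-ins for the VAE outputs, uses torch.diag_embed to get a (B, 2, 2) stack of diagonal covariances:

```python
import torch
from torch.distributions import MultivariateNormal

B = 4                              # hypothetical batch size
mu = torch.randn(B, 2)             # stand-in for the VAE's mean output, shape (B, 2)
sigma = torch.rand(B, 2) + 0.1     # stand-in for per-axis variances, kept positive

# torch.diag on a (B, 2) tensor returns its main diagonal (shape (2,)),
# not B covariance matrices; diag_embed builds one (2, 2) diagonal matrix per row.
cov = torch.diag_embed(sigma)      # shape (B, 2, 2)
m = MultivariateNormal(mu, covariance_matrix=cov)

print(m.batch_shape, m.event_shape)  # torch.Size([4]) torch.Size([2])
```

The result is a single distribution object with batch_shape (B,), so sampling and evaluating it works on the whole batch at once.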

So to get the probability density at any point (x, y), I do this:

prob = m.log_prob(torch.tensor([x, y], dtype=torch.float)).exp()
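Two details are easy to miss here: log_prob expects a tensor rather than a Python tuple, and exp() of a log-density gives a density value, not a probability, so it can exceed 1 for small variances. A small sanity check, with made-up mu and variances, compares the result against the closed-form product of two 1-D normal densities (valid because the covariance is diagonal):

```python
import math
import torch
from torch.distributions import MultivariateNormal

mu = torch.tensor([1.0, 2.0])
var = torch.tensor([0.5, 2.0])            # diagonal of the covariance (variances)
m = MultivariateNormal(mu, torch.diag(var))

x, y = 1.5, 1.0
point = torch.tensor([x, y])              # log_prob needs a tensor, not a tuple
density = m.log_prob(point).exp()

# Diagonal covariance: the 2-D density factorizes into two 1-D normal densities.
manual = 1.0
for v, mean, s2 in zip((x, y), mu.tolist(), var.tolist()):
    manual *= math.exp(-(v - mean) ** 2 / (2 * s2)) / math.sqrt(2 * math.pi * s2)

print(density.item(), manual)
```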

Suppose I have a batch B of images with shape (N x Channels x W x H). I pass it through the VAE and get the parameters mu and sigma for each image. I then want to create N tensors of size W x H holding density values. For example, for image n in the batch, I create a tensor T of the same spatial size (W x H), where position (i, j) of T (0 <= i < W, 0 <= j < H) holds prob = m.log_prob(p).exp() for the point p = (i, j), with m the distribution built from image n's mu and sigma. This is done for every position of every image in the batch. Then I multiply T element-wise with each channel of image n. Image n' gets its own map T', and so on: every image in the batch has a different T.
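For a single image, the map T described above can be built without looping over pixels. A minimal sketch, assuming the (C, W, H) layout from the post and using made-up values for mu, the variances and the image size: torch.meshgrid produces every coordinate (i, j), and log_prob treats the last dimension as the (x, y) event.

```python
import torch
from torch.distributions import MultivariateNormal

C, W, H = 3, 5, 4                      # hypothetical image size
img = torch.rand(C, W, H)              # one image, laid out as (C, W, H)
mu = torch.tensor([2.0, 1.5])          # stand-in for the VAE's mean for this image
var = torch.tensor([1.0, 0.8])         # stand-in for its variances
m = MultivariateNormal(mu, torch.diag(var))

# ii[i, j] = i and jj[i, j] = j, so coords[i, j] = (i, j) for 0 <= i < W, 0 <= j < H
ii, jj = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="ij")
coords = torch.stack([ii, jj], dim=-1).float()   # shape (W, H, 2)

# The last dimension is consumed as the event, leaving a (W, H) map of densities
T = m.log_prob(coords).exp()                     # T[i, j] = density at (i, j)

weighted = img * T                               # T broadcasts over the C channels
```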

I then pass the resulting batch (each image multiplied by its own map T) to other networks for further processing. How can I do this so that gradients from those downstream networks backpropagate to the VAE?
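Building the distribution, calling log_prob, taking exp() and multiplying are all differentiable torch operations, so autograd can carry gradients back to mu and sigma as long as they stay in the graph (no .detach(), .item() or NumPy round-trips). A quick check under stated assumptions: leaf tensors with requires_grad stand in for the encoder outputs, log_var is a hypothetical log-variance parameterization (exp() keeps the variance positive), and loss is a placeholder for the downstream networks.

```python
import torch
from torch.distributions import MultivariateNormal

B, C, W, H = 2, 3, 4, 4
imgs = torch.rand(B, C, W, H)

# Stand-ins for VAE outputs; requires_grad mimics tensors produced by the encoder
mu = torch.tensor([[1.0, 2.0], [2.5, 1.0]], requires_grad=True)
log_var = torch.zeros(B, 2, requires_grad=True)   # hypothetical log-variance output

# Pixel coordinates (i, j), shape (W*H, 2)
grid = torch.cartesian_prod(torch.arange(W), torch.arange(H)).float()

maps = torch.stack([
    MultivariateNormal(m_, torch.diag(lv.exp())).log_prob(grid).exp()
    for m_, lv in zip(mu, log_var)
]).view(B, 1, W, H)

out = imgs * maps
loss = out.sum()        # placeholder for whatever the downstream networks compute
loss.backward()         # gradients flow through exp, log_prob and the distribution

print(mu.grad.shape, log_var.grad.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
```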

Just a reminder: as I am working with batches, m.sample() returns a sample of shape (batch_size, 2), and therefore m.log_prob() receives a tensor of the same shape.
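For a batched distribution, log_prob does not strictly need the input to match the sample shape: it only needs to broadcast against batch_shape + event_shape. A sketch with random parameters (built with diag_embed, assuming sigma holds variances) shows both cases; inserting a singleton axis evaluates every point under every distribution in the batch.

```python
import torch
from torch.distributions import MultivariateNormal

B = 3
mu = torch.randn(B, 2)
var = torch.rand(B, 2) + 0.1
m = MultivariateNormal(mu, torch.diag_embed(var))   # batch_shape (B,), event_shape (2,)

s = m.sample()
print(s.shape)                 # torch.Size([3, 2])
print(m.log_prob(s).shape)     # torch.Size([3]): one log-density per batch element

# (10, 1, 2) broadcasts against batch (3,) + event (2,) -> result of shape (10, 3)
points = torch.rand(10, 2) * 5          # hypothetical set of 10 (x, y) coordinates
lp = m.log_prob(points.unsqueeze(1))
print(lp.shape)                # torch.Size([10, 3])
```

Here lp[k, b] is the log-density of point k under the distribution of batch element b.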

I solved the problem with the code below:

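import torch
from torch.distributions import MultivariateNormal

def generate_grid(h, w):
    # Returns a (w*h, 2) float tensor of pixel coordinates (i, j),
    # with i in [0, w) varying slowest and j in [0, h) varying fastest,
    # matching the row-major layout of a (w, h) map
    x = torch.arange(0, h)
    y = torch.arange(0, w)

    grid = torch.stack([y.expand(h,-1).t().flatten(), x.repeat(w)]).t()
    return grid.float()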

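# Density of a 2-D Gaussian with diagonal covariance, evaluated at every row of `tensor`.
# torch.diag(sigma) is used as the covariance, so sigma must hold variances;
# if the VAE outputs standard deviations, use torch.diag(sigma ** 2) instead.
evaluate_mn_dist = lambda mu, sigma, tensor: MultivariateNormal(mu, torch.diag(sigma)).log_prob(tensor).exp()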

w, h = imgs[0].shape[1:] # imgs is the 4-D batch (B, C, W, H)
grid = generate_grid(h, w).to(imgs.device)  # keep the grid on the same device as the images

# For each image in the batch, evaluate the Gaussian whose parameters come from
# the VAE at every pixel coordinate; maps has shape (B, W*H)
maps = torch.stack([evaluate_mn_dist(mean, std, grid) for mean, std in zip(mu_final, sigma_final)])

# reshape the attention maps to shape (B, 1, W, H)
maps = maps.view(-1, 1, w, h)
# multiply the images with the maps (broadcast over the C channels)
imgs_ = imgs * maps
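The helper generate_grid builds the same coordinate list as torch.cartesian_prod(torch.arange(w), torch.arange(h)), and the per-image list comprehension can be collapsed into one batched distribution. A sketch of that alternative, assuming sigma_final holds variances of shape (B, 2) as torch.diag(sigma) implies; attention_maps is a hypothetical helper name, and the check below compares it with the loop version on made-up values:

```python
import torch
from torch.distributions import MultivariateNormal

def generate_grid(h, w):
    # The helper from the post, reproduced for comparison
    x = torch.arange(0, h)
    y = torch.arange(0, w)
    grid = torch.stack([y.expand(h, -1).t().flatten(), x.repeat(w)]).t()
    return grid.float()

def attention_maps(mu, var, w, h):
    """Hypothetical vectorized helper: mu, var of shape (B, 2) -> maps of shape (B, 1, w, h)."""
    grid = torch.cartesian_prod(torch.arange(w, device=mu.device),
                                torch.arange(h, device=mu.device)).float()  # (w*h, 2)
    dist = MultivariateNormal(mu, torch.diag_embed(var))                   # batch_shape (B,)
    probs = dist.log_prob(grid.unsqueeze(1)).exp()                          # (w*h, B)
    return probs.t().reshape(-1, 1, w, h)

B, C, W, H = 2, 3, 6, 5
imgs = torch.rand(B, C, W, H)
mu_final = torch.tensor([[2.0, 2.0], [4.0, 1.0]])
sigma_final = torch.tensor([[1.0, 2.0], [0.5, 1.5]])

# Loop version, as in the post
loop_maps = torch.stack([
    MultivariateNormal(m, torch.diag(s)).log_prob(generate_grid(H, W)).exp()
    for m, s in zip(mu_final, sigma_final)
]).view(-1, 1, W, H)

vec_maps = attention_maps(mu_final, sigma_final, W, H)
imgs_ = imgs * vec_maps
```

The batched form builds a single distribution object and makes one log_prob call for the whole batch instead of one per image; gradients still reach mu_final and sigma_final the same way.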