Hi guys, I am in what I believe to be a tricky situation. I am using a VAE-based net to generate the parameters mu and sigma (both 2D, i.e. (x, y)), and from these values I create a distribution:
m = MultivariateNormal(mu, torch.diag_embed(sigma))
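Because mu and sigma are batched (shape (N, 2)), the covariance needs to be a batch of 2 x 2 diagonal matrices. torch.diag does not do that for a 2D input: it returns that matrix's diagonal instead. torch.diag_embed builds one diagonal matrix per row. A minimal sketch, assuming sigma holds per-axis variances (which is how the torch.diag(sigma) covariance treats it); the random tensors and N = 4 are illustrative stand-ins for the VAE outputs. If sigma is meant to be a standard deviation, passing scale_tril=torch.diag_embed(sigma) would be the matching choice.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N = 4  # batch size (illustrative)

# Stand-ins for the VAE outputs: one 2D mean and one 2D diagonal per image.
mu = torch.randn(N, 2)
sigma = torch.rand(N, 2) + 0.1  # must be strictly positive

# torch.diag on an (N, 2) input returns its diagonal, not N matrices.
print(torch.diag(sigma).shape)  # torch.Size([2])

# torch.diag_embed turns each row of sigma into a 2x2 diagonal matrix,
# giving a covariance of shape (N, 2, 2), i.e. N independent 2D Gaussians.
m = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(sigma))

print(m.batch_shape)  # torch.Size([4])
print(m.event_shape)  # torch.Size([2])
```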
So, to get the probability density at any point (x, y), I do this:
prob = m.log_prob(torch.tensor([x, y])).exp()
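log_prob expects a tensor (a plain Python tuple has no .shape and fails validation), and exp(log_prob) is a density value rather than a probability, so it can exceed 1 for small variances. A small check on a single, non-batched Gaussian with illustrative numbers: at the mean, a 2D Gaussian's density is 1 / (2 * pi * sqrt(det(cov))).

```python
import math
import torch
from torch.distributions import MultivariateNormal

# A single (non-batched) 2D Gaussian with illustrative parameters.
mu = torch.tensor([1.0, 2.0])
sigma = torch.tensor([0.5, 2.0])  # per-axis variances
m = MultivariateNormal(mu, covariance_matrix=torch.diag(sigma))

x, y = 1.0, 2.0
point = torch.tensor([x, y])       # log_prob needs a tensor, not a tuple
density = m.log_prob(point).exp()  # a density value, not a probability

# Closed form at the mean: 1 / (2 * pi * sqrt(det(cov))), det = product of variances.
expected = 1.0 / (2 * math.pi * math.sqrt(float(sigma.prod())))
print(density.item(), expected)
```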
Suppose I have a batch B of N images of shape (N x Channels x W x H). I pass it through the VAE and get mu and sigma for each image. Then I want to create N tensors of size W x H filled with density values. For image n of the batch, I create a tensor T_n of the same spatial size (W x H), where the entry at position (i, j) (0 <= i < W, 0 <= j < H) is the density of that image's distribution m_n at the point (i, j), i.e. exp(m_n.log_prob(point)). This is done for every position of every image in the batch. Then I multiply each channel of image n element-wise by T_n, so each image in the batch gets its own T, built from its own mu and sigma.
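One way to build all the T_n at once, without Python loops over images or pixels, is to evaluate the batched distribution on a grid of pixel coordinates and let log_prob broadcast. A minimal sketch under these assumptions: images are laid out as (N, C, W, H) as described above, mu is expressed in pixel coordinates (if the VAE outputs normalized coordinates, the grid would need the same normalization), sigma holds variances, and all sizes and random tensors are illustrative.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N, C, W, H = 2, 3, 5, 4  # illustrative sizes; images laid out as (N, C, W, H)

images = torch.rand(N, C, W, H)
mu = torch.rand(N, 2) * torch.tensor([W, H], dtype=torch.float32)  # means in pixel coords
sigma = torch.rand(N, 2) + 1.0                                      # per-axis variances
m = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(sigma))

# Every pixel coordinate (i, j), with i indexing W and j indexing H.
ii, jj = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                        torch.arange(H, dtype=torch.float32), indexing="ij")
grid = torch.stack([ii, jj], dim=-1).reshape(-1, 1, 2)  # (W*H, 1, 2)

# The singleton dim broadcasts against the batch shape (N,), so every image's
# distribution is evaluated at every pixel in one call: result is (W*H, N).
log_p = m.log_prob(grid)

# Rearrange to one W x H map per image: (N, W, H).
T = log_p.exp().transpose(0, 1).reshape(N, W, H)

# Multiply every channel of image n by its own T[n]; unsqueeze adds the channel dim.
out = images * T.unsqueeze(1)  # (N, C, W, H)
print(T.shape, out.shape)
```

The row-major flattening of the (W, H) grid is undone by the final reshape, so T[n, i, j] is the density of image n's Gaussian at pixel (i, j).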
I then pass the resulting batch (each image weighted by its own T_n) to other networks for further processing. How do I implement this so that the gradients from those downstream networks backpropagate into the VAE?
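Since T is computed by evaluating log_prob on fixed pixel coordinates (no sampling involved), every step is an ordinary differentiable tensor operation of mu and sigma, so autograd should carry gradients back into the VAE as long as the graph is not broken (for example by .detach(), .item(), or a round trip through NumPy). A minimal sketch to check this end to end; `encoder` and `downstream` are hypothetical stand-ins for the VAE head and the "other nets", and the sigmoid/softplus mappings that keep mu inside the image and sigma positive are assumptions, not part of the original setup. If points are ever sampled, m.rsample() is the reparameterized option for MultivariateNormal; m.sample() runs under torch.no_grad().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N, C, W, H = 2, 3, 8, 8  # illustrative sizes

# Hypothetical stand-ins: `encoder` plays the VAE part that outputs mu and sigma,
# `downstream` plays the other nets that consume the weighted images.
encoder = nn.Linear(C * W * H, 4)
downstream = nn.Conv2d(C, 1, kernel_size=3, padding=1)

images = torch.rand(N, C, W, H)
params = encoder(images.flatten(1))                                    # (N, 4)
mu = torch.sigmoid(params[:, :2]) * torch.tensor([W - 1.0, H - 1.0])   # pixel coords
sigma = F.softplus(params[:, 2:]) + 1e-3                               # positive variances

m = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(sigma))
ii, jj = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                        torch.arange(H, dtype=torch.float32), indexing="ij")
grid = torch.stack([ii, jj], dim=-1).reshape(-1, 1, 2)                 # (W*H, 1, 2)
T = m.log_prob(grid).exp().transpose(0, 1).reshape(N, W, H)            # (N, W, H)

out = downstream(images * T.unsqueeze(1))
loss = out.mean()
loss.backward()

# The encoder received gradients through log_prob -> mu, sigma.
print("encoder grad norm:", encoder.weight.grad.norm().item())
```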
Just a reminder: since I am working with batches, m.sample() returns a tensor of shape (Batch size, 2), and m.log_prob() therefore expects an input whose trailing dimensions match that shape.
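The (Batch size, 2) shape constrains only the trailing dimensions of what log_prob accepts: extra leading dimensions are allowed, and a size-1 batch dimension broadcasts across all images. The short check below shows both cases; N, K and the random parameters are illustrative.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N, K = 4, 6  # batch size and number of query points (illustrative)

mu = torch.randn(N, 2)
sigma = torch.rand(N, 2) + 0.1  # per-axis variances
m = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(sigma))

s = m.sample()               # one point per image: (N, 2)
lp_own = m.log_prob(s)       # (N,): each point under its own image's Gaussian

# log_prob only needs the trailing dims to broadcast against (N, 2):
shared = torch.randn(K, 1, 2)       # K points shared by all images
lp_shared = m.log_prob(shared)      # (K, N): every point under every Gaussian

print(s.shape, lp_own.shape, lp_shared.shape)
```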