When implementing a VAE, which should I use: torch.distributions.normal or torch.distributions.multivariate_normal.MultivariateNormal?

When implementing a VAE as below, I have two questions.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
# from torch.distributions import MultivariateNormal

class VAE(nn.Module):
    def __init__(self, data_size=28):  # the original used a global data_size; 28 (e.g. MNIST) is assumed here
        super(VAE, self).__init__()
        self.data_size = data_size

        self.fc1 = nn.Linear(self.data_size ** 2, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)

        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, self.data_size ** 2)

    # encode/decode were not shown in the original post; a standard MLP is assumed
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)  # mu, logvar

    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.data_size ** 2))
        # Normal's scale argument is the std, so use exp(0.5 * logvar), not exp(logvar) (the variance)
        dist = Normal(mu, torch.exp(0.5 * logvar))
        # dist = MultivariateNormal(mu, ??)
        return self.decode(dist.rsample()), dist
  1. In the forward function, which is correct: using Normal or MultivariateNormal?
  2. If MultivariateNormal is the correct choice, how should I create as many diagonal matrices as the batch size for the variance-covariance matrix (the part written as "??" in the code)? I could not find a function in PyTorch that accomplishes this.

Hi, did you find a solution? I wanted to post the same question, but then I found yours.

I would also be interested in knowing whether you found a solution for this.


P(z) is a multivariate Gaussian, ~ N(z | 0, I).
The trick is that the components of the z vectors in the batch are independent, normally-distributed values. [1]
Therefore samples of z can be drawn from a simple distribution, i.e. N(0, I), using torch.randn_like.
Given the mean and log-variance from the encoder, we first sample e ~ N(0, I), then
compute z = mu(X) + sigma(X) * e, where sigma(X) = exp(0.5 * logvar(X)).

std = torch.exp(0.5 * variance_ln)   # sigma(X) = exp(0.5 * log-variance)
e = torch.randn_like(variance_ln)    # e ~ N(0, I); randn_like already matches shape, dtype and device
z = mu + e * std                     # "reparameterization trick"

[1] https://arxiv.org/pdf/1606.05908.pdf
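
For reference, the same sampling step can also be written directly with torch.distributions: Normal(mu, std).rsample() applies exactly this reparameterization internally, and a batch of diagonal covariance matrices for MultivariateNormal (the "??" in the original post) can be built with torch.diag_embed. The sketch below is just one possible way to wire this up, not code from the thread; the shapes and the mu/logvar tensors are assumed to come from an encoder like the one above.

import torch
from torch.distributions import Normal, MultivariateNormal, Independent

# assumed encoder outputs: mu and logvar of shape (batch_size, latent_dim)
mu = torch.zeros(8, 20)
logvar = torch.zeros(8, 20)
std = torch.exp(0.5 * logvar)

# Option 1: a diagonal Gaussian as a batch of independent univariate Normals.
# rsample() internally computes mu + std * eps, i.e. the reparameterization trick.
dist = Normal(mu, std)
z1 = dist.rsample()                       # shape (8, 20)

# Option 2: MultivariateNormal with one diagonal covariance matrix per sample.
# torch.diag_embed turns the (8, 20) variances into (8, 20, 20) diagonal matrices.
cov = torch.diag_embed(std ** 2)
mvn = MultivariateNormal(mu, covariance_matrix=cov)
z2 = mvn.rsample()                        # shape (8, 20)

# With a diagonal covariance both describe the same distribution; only the
# log_prob shapes differ: Normal gives per-dimension values of shape (8, 20),
# while MultivariateNormal sums over the event dimension and gives shape (8,).
# Independent(Normal(mu, std), 1) matches the MultivariateNormal behaviour
# without materialising the covariance matrices.
lp = Independent(Normal(mu, std), 1).log_prob(z1)   # shape (8,)

In practice the Normal/Independent version is usually preferred for a diagonal covariance, since MultivariateNormal performs a Cholesky factorisation of the full matrices.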
