I have two questions about implementing a VAE as shown below.
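```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class VAE(nn.Module):
    def __init__(self, data_size):
        super().__init__()
        self.data_size = data_size
        self.fc1 = nn.Linear(self.data_size ** 2, 400)
        self.fc21 = nn.Linear(400, 20)  # mean of q(z|x)
        self.fc22 = nn.Linear(400, 20)  # log-variance of q(z|x)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, self.data_size ** 2)

    def encode(self, x):  # assumed body, not shown in the original post
        h = torch.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def decode(self, z):  # assumed body, not shown in the original post
        return torch.sigmoid(self.fc4(torch.relu(self.fc3(z))))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.data_size ** 2))
        # Normal takes a standard deviation, not a variance: std = exp(logvar / 2)
        dist = Normal(mu, torch.exp(0.5 * logvar))
        # dist = MultivariateNormal(mu, ??)
        return self.decode(dist.rsample()), dist
```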
- In `forward`, which is correct: `Normal` or `MultivariateNormal`?
- If `MultivariateNormal` is correct, how do I build the covariance matrix (the `??` in the code), i.e. one diagonal matrix per sample, as many as the batch size? I could not find a PyTorch function that does this.
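A minimal sketch of both options, assuming a diagonal Gaussian posterior with `mu` and `logvar` of shape `(batch, latent)`; the sizes 4 and 20 are arbitrary illustration values. `torch.diag_embed` turns a `(batch, latent)` tensor into a `(batch, latent, latent)` stack of diagonal matrices, and for a diagonal covariance `MultivariateNormal` should give the same log-probabilities as `Normal` wrapped in `Independent` (which treats the last dimension as one event instead of 20 separate ones).

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
batch, latent = 4, 20  # illustration sizes only
mu = torch.randn(batch, latent, dtype=torch.float64)
logvar = torch.randn(batch, latent, dtype=torch.float64)
std = torch.exp(0.5 * logvar)

# Option 1: independent Normals, with the last dim reinterpreted as the event dim
diag_normal = Independent(Normal(mu, std), 1)

# Option 2: MultivariateNormal with one diagonal covariance matrix per batch row
cov = torch.diag_embed(std ** 2)  # shape (batch, latent, latent)
mvn = MultivariateNormal(mu, covariance_matrix=cov)

# Same distribution via scale_tril; a diagonal std matrix is already its Cholesky factor
mvn_tril = MultivariateNormal(mu, scale_tril=torch.diag_embed(std))

z = diag_normal.rsample()
lp1 = diag_normal.log_prob(z)
lp2 = mvn.log_prob(z)
lp3 = mvn_tril.log_prob(z)
print(lp1.shape)  # torch.Size([4])
print(torch.allclose(lp1, lp2), torch.allclose(lp1, lp3))
```

Since the covariance is diagonal, the `Normal`/`Independent` form avoids building and factorizing `latent x latent` matrices; `MultivariateNormal` only becomes necessary if the encoder predicts correlations between latent dimensions.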