Problems with a VAE using a multivariate Gaussian

I am using a Variational Autoencoder (VAE) to reconstruct spectral data and to obtain a lower-dimensional (latent) representation from the encoder.
With mu and std coming from two FC-layer outputs, I build the approximate posterior q(z|x) as a normal distribution:

q = torch.distributions.Normal(mu, std)
z = q.rsample()

and I compute the KL divergence against the prior p(z):

p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))

The KL divergence (a Monte Carlo estimate using the sampled z):

log_qzx = q.log_prob(z)
log_pz = p.log_prob(z)
kl = (log_qzx - log_pz)
kl = kl.sum(-1)

I also use a reconstruction term based on a Gaussian log-likelihood:

def gaussian_likelihood(x_hat, logscale, x):
    scale = torch.exp(logscale)  # logscale is zero in my case, so the scale is simply unit
    mean = x_hat
    dist = torch.distributions.Normal(mean, scale)
    # measure p(x|z)
    log_pxz = dist.log_prob(x)
    return log_pxz.sum(dim=1)

The final loss is KL + recon, which I minimize with the Adam optimizer using standard hyperparameters. To be fair, I am getting good results and the expected outcomes from this setup.
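
Putting these pieces together, my univariate training step looks roughly like the sketch below (the encoder/decoder calls and log_scale are placeholders for my actual model; recon enters the loss with a minus sign because gaussian_likelihood returns a log-likelihood):

mu, std = self.encoder(x)                              # placeholder: two FC heads give mu and std
q = torch.distributions.Normal(mu, std)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
z = q.rsample()
x_hat = self.decoder(z)                                # placeholder decoder call

recon = gaussian_likelihood(x_hat, log_scale, x)       # summed log p(x|z); log_scale is a zero/learnable tensor
kl = (q.log_prob(z) - p.log_prob(z)).sum(-1)           # Monte Carlo KL(q(z|x) || p(z))
loss = (kl - recon).mean()                             # KL minus log-likelihood, i.e. the negative ELBO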

I wanted to implement the same strategy with a multivariate normal distribution (full covariance between latent dimensions). To do this, I did the following:

def kl_divergence(z, mu, lt_d_Chol):
    b, l = mu.size()
    # standard normal prior p(z) with identity covariance
    p = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.zeros_like(mu),
                            covariance_matrix=torch.eye(l).reshape((1, l, l)).repeat(b, 1, 1).to(mu.get_device()))
    # full-covariance posterior q(z|x), parameterized by its lower-triangular Cholesky factor
    q = torch.distributions.multivariate_normal.MultivariateNormal(loc=mu, scale_tril=lt_d_Chol)
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)
    kl = (log_qzx - log_pz)
    return kl

lt_d_Chol is a rearrangement of an FC-layer output into a lower-triangular Cholesky factor, built as:

def get_low_cholesky(self, latent_dim, logSigma, lt_indices, d_indices):
    """
    logSigma   -> torch tensor from an FC layer (after ReLU) on the GPU
    lt_indices -> numpy: np.tril_indices(latent_dim, k=-1, m=latent_dim)
    d_indices  -> numpy: np.diag_indices(latent_dim)
    """
    # elements per sample: latent_dim * (latent_dim - 1) / 2 + latent_dim
    Sigma = logSigma.cpu()  # torch.exp(0.5 * logSigma) # todo: needed?
    bSize = Sigma.size(0)
    low_tri_Chol = torch.zeros(bSize, latent_dim, latent_dim)
    for b in range(bSize):
        b_lower_mat = low_tri_Chol[b, ...]
        b_lower_mat[lt_indices] = Sigma[b, :-latent_dim]  # strictly lower-triangular entries
        b_lower_mat[d_indices] = 1 + nn.functional.elu(Sigma[b, -latent_dim:])  # ELU because the diagonal of a Cholesky factor must be positive
        low_tri_Chol[b, ...] = b_lower_mat
    low_tri_Chol = low_tri_Chol.to(logSigma.get_device())
    return low_tri_Chol
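
For context, this is roughly how get_low_cholesky feeds the multivariate kl_divergence in my setup (a sketch; latent_dim, the encoder call, and the exact shape of logSigma are illustrative placeholders):

import numpy as np

latent_dim = 16                                               # illustrative value
lt_indices = np.tril_indices(latent_dim, k=-1, m=latent_dim)  # strictly lower-triangular positions
d_indices = np.diag_indices(latent_dim)                       # diagonal positions

# placeholder encoder call: logSigma has latent_dim*(latent_dim-1)//2 + latent_dim values per sample
mu, logSigma = self.encoder(x)
lt_d_Chol = self.get_low_cholesky(latent_dim, logSigma, lt_indices, d_indices)

q = torch.distributions.multivariate_normal.MultivariateNormal(loc=mu, scale_tril=lt_d_Chol)
z = q.rsample()
kl = kl_divergence(z, mu, lt_d_Chol)                          # Monte Carlo KL against the standard normal prior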

Similarly, the reconstruction loss:

def gaussian_likelihood(self, x_hat, x):
    mean = x_hat
    b, l = mean.size()
    # identity covariance, i.e. a fixed unit-variance Gaussian likelihood per sample
    dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=mean,
                    covariance_matrix=torch.eye(l).reshape((1, l, l)).repeat(b, 1, 1).to(mean.get_device()))
    log_pxz = dist.log_prob(x)
    return log_pxz

But with the same hyperparameters and optimizer, the loss does not decrease the way it does with the univariate normal distribution.

  • Is there anything conceptually wrong with assuming that 1D vector/spectral data can be reconstructed better with a multivariate distribution than with a univariate one?
  • The output logSigma (or Sigma) comes from an FC layer followed by ReLU, so is the ELU necessary, given that the diagonal of the Cholesky factor of the covariance matrix must be positive?
  • The multivariate code runs much slower than the univariate one. Is that because of the repeated moving of tensors between CPU and GPU?

Please see the loss curves:

[Image: loss curve for the univariate case]
[Image: loss curve for the multivariate case (too slow, so only 80 epochs ran in the time the univariate case ran 1000)]

One reason for the slow run could be the size of the covariance / lower-triangular matrices, which have dimensions batchSize x latent_dim x latent_dim.
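
As a rough illustration with hypothetical numbers (not my actual configuration):

batch_size, latent_dim = 128, 64                       # hypothetical values
full_cov_elems = batch_size * latent_dim * latent_dim  # 524288 float32 values per batch (~2 MB)
diag_elems = batch_size * latent_dim                   # 8192 values (~32 KB) in the univariate case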