Multivariate Gaussian Variational Autoencoder (the decoder part)

I have been reading this paper: https://arxiv.org/pdf/1312.6114.pdf to build a variational autoencoder. Then I stumbled upon the VAE example that PyTorch offers: https://github.com/pytorch/examples/blob/master/vae/main.py#L39. This one is for binary data because it uses a Bernoulli distribution in the decoder (in practice, a sigmoid activation applied to the outputs). Below is the part of the paper where they explicitly say so:

I am more interested in real-valued data in (-∞, ∞) and need the decoder of this VAE to reconstruct a multivariate Gaussian distribution instead. In short: how can I achieve this with PyTorch’s example above?

I am a little confused, so let me elaborate. For real-valued data they wrote:

To implement Eq. (12), I did the following:

diff --git a/vae/main.py b/vae/main.py
index 6286592..246dd97 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -44,6 +44,7 @@ class VAE(nn.Module):
         self.fc21 = nn.Linear(400, 20)
         self.fc22 = nn.Linear(400, 20)
         self.fc3 = nn.Linear(20, 400)
+        self.fc31 = nn.Linear(20, 400)
         self.fc4 = nn.Linear(400, 784)

     def encode(self, x):
@@ -56,8 +57,10 @@ class VAE(nn.Module):
         return mu + eps*std

     def decode(self, z):
-        h3 = F.relu(self.fc3(z))
-        return torch.sigmoid(self.fc4(h3))
+        mu = F.relu(self.fc3(z))
+        logvar = F.relu(self.fc31(z))
+        gaussian = self.reparameterize(mu, logvar)
+        return self.fc4(gaussian)

     def forward(self, x):
         mu, logvar = self.encode(x.view(-1, 784))
@@ -71,15 +74,17 @@ optimizer = optim.Adam(model.parameters(), lr=1e-3)

 # Reconstruction + KL divergence losses summed over all elements and batch
 def loss_function(recon_x, x, mu, logvar):
-    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

     # see Appendix B from VAE paper:
     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
     # https://arxiv.org/abs/1312.6114
     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+    criterion = torch.nn.MSELoss()
+    MSE = criterion(recon_x, x.view(-1, 784)) * .5
+
     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

-    return BCE + KLD
+    return MSE + KLD


 def train(epoch):

As you can see, I removed the sigmoid activation function, reconstructed the Gaussian by computing a mean and log-variance from z and calling the reparameterize function, and used the MSE loss function. Is this approach correct? If not, what should the decoder look like?

I would appreciate any input about this.

I think you want

h = F.relu(self.fc3(z)) # KW use tanh instead of relu
mu = self.fc4(h)
sigma = self.fc5(h)
return mu, sigma

and then pass mu, sigma and x into your criterion to give the negative log likelihood of x as a vector sampled from N(mu_i, sigma_i^2) (i.e. you need to take the log PDF).
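A minimal sketch of that suggestion, assuming the layer sizes of the PyTorch MNIST example (784 → 400 → 20). The class `GaussianDecoder` and the helper `gaussian_nll` are hypothetical names. The decoder returns a mean and a log-variance per pixel (log-variance rather than sigma, so the variance stays positive without a constraint), and the reconstruction term is the negative log PDF of x under independent normals, computed with `torch.distributions.Normal`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianDecoder(nn.Module):
    """Hypothetical decoder p(x|z) = N(mu(z), diag(exp(logvar(z))))."""

    def __init__(self, latent_dim=20, hidden_dim=400, data_dim=784):
        super().__init__()
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, data_dim)  # mean head
        self.fc5 = nn.Linear(hidden_dim, data_dim)  # log-variance head

    def forward(self, z):
        h = F.relu(self.fc3(z))  # the paper's MLP uses tanh here
        return self.fc4(h), self.fc5(h)


def gaussian_nll(x, mu, logvar):
    """Negative log-likelihood of x under independent normals, summed over elements and batch."""
    std = torch.exp(0.5 * logvar)
    return -Normal(mu, std).log_prob(x).sum()


decoder = GaussianDecoder()
z = torch.randn(8, 20)   # stand-in for the encoder's reparameterized sample
x = torch.randn(8, 784)  # stand-in for a batch of real-valued inputs
mu_x, logvar_x = decoder(z)
recon_loss = gaussian_nll(x, mu_x, logvar_x)
```

The input x is only scored here, never reconstructed by sampling; this loss replaces the BCE term and is added to the KL term as before.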

So the Gaussian at the reconstruction step has nothing to do (well, except being conditional on the latents) with the Gaussian from the latents (which is the bit where you do the reparametrization and things).

I think what might cause the confusion is that the “reconstruction error” is easily taken too literally when looking at the example implementation. In the paper they only take and maximize the (log) likelihood of the input data under the output distribution rather than doing an actual reconstruction (by sampling or some such). In evaluation mode, you would be expected to sample from the Gaussian.
The last paragraph of section 2.3 in the paper talks about how in the probabilistic VAE framework this term is a […] reconstruction error in auto-encoder parlance (notice the indefinite article) or maybe rather the analogue of the reconstruction error.
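To illustrate the evaluation-mode point, here is a hedged sketch of generating data once the decoder outputs a mean and a log-variance. The linear layer `head` and the function `decode` are hypothetical stand-ins so the snippet runs on its own; with a trained model you would call its own decoder instead. Sampling adds noise with the learned per-element variance; showing `mu_x` alone gives the common “mean reconstruction”.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, data_dim = 20, 784

# Hypothetical stand-in for a trained decoder that returns (mu, logvar).
head = nn.Linear(latent_dim, 2 * data_dim)


def decode(z):
    mu, logvar = head(z).chunk(2, dim=-1)
    return mu, logvar


with torch.no_grad():
    z = torch.randn(64, latent_dim)  # draw latents from the prior N(0, I)
    mu_x, logvar_x = decode(z)
    eps = torch.randn_like(mu_x)
    x_sample = mu_x + torch.exp(0.5 * logvar_x) * eps  # one draw from p(x|z)
    x_mean = mu_x  # the mean of p(x|z), often displayed instead of a sample
```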

Best regards

Thomas

I hope I have understood correctly. I changed the VAE, and it now looks like this:

diff --git a/vae/main.py b/vae/main.py
index 6286592..606977b 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -45,6 +45,7 @@ class VAE(nn.Module):
         self.fc22 = nn.Linear(400, 20)
         self.fc3 = nn.Linear(20, 400)
         self.fc4 = nn.Linear(400, 784)
+        self.fc5 = nn.Linear(400, 784)
 
     def encode(self, x):
         h1 = F.relu(self.fc1(x))
@@ -56,13 +57,17 @@ class VAE(nn.Module):
         return mu + eps*std
 
     def decode(self, z):
-        h3 = F.relu(self.fc3(z))
-        return torch.sigmoid(self.fc4(h3))
+        h = F.relu(self.fc3(z))
+        mu = self.fc4(h)
+        logvar = self.fc5(h)
+        recon_x = self.reparameterize(mu, logvar)
+        return recon_x, mu, logvar
 
     def forward(self, x):
         mu, logvar = self.encode(x.view(-1, 784))
         z = self.reparameterize(mu, logvar)
-        return self.decode(z), mu, logvar
+        recon_x, mu, logvar = self.decode(z)
+        return recon_x, mu, logvar
 
 
 model = VAE().to(device)
@@ -71,7 +76,8 @@ optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
 # Reconstruction + KL divergence losses summed over all elements and batch
 def loss_function(recon_x, x, mu, logvar):
-    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
+    criterion = torch.nn.MSELoss()
+    MSE = criterion(recon_x, x.view(-1, 784)) * .5
 
     # see Appendix B from VAE paper:
     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
@@ -79,7 +85,7 @@ def loss_function(recon_x, x, mu, logvar):
     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
 
-    return BCE + KLD
+    return MSE + KLD
 
 
 def train(epoch):
@@ -127,6 +133,8 @@ if __name__ == "__main__":
         test(epoch)
         with torch.no_grad():
             sample = torch.randn(64, 20).to(device)
-            sample = model.decode(sample).cpu()
+            sample, mu, logvar = model.decode(sample)
+            sample = sample.cpu()
+

The mu and logvar used in the loss function now come from the decoder, and to reconstruct x I use self.reparameterize (not sure about this). Is this correct? I am sorry if I have misunderstood. At least these changes reduce the loss; see the Bernoulli and Gaussian outputs here: https://gist.github.com/muammar/0c0c0c53f351c85c0680017a8c41ce62.

Thanks for this clarification @tom.

You are right about logvar instead of sigma, but there should be no recon_x. Instead, the recon loss is the neg log likelihood of x under the normal dist given by mu, logsigma with independent elements.
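Spelled out, for a decoder that outputs mu_i and logvar_i = log sigma_i^2 for each of the D elements of x, the reconstruction loss for one sample is the negative log density of a normal with diagonal covariance. Here x is the observed input; nothing is sampled from the decoder:

```latex
-\log \mathcal{N}\left(x;\ \mu,\ \operatorname{diag}(\sigma^2)\right)
  = \frac{1}{2} \sum_{i=1}^{D} \left[ \log(2\pi) + \log\sigma_i^2 + \frac{(x_i - \mu_i)^2}{\sigma_i^2} \right]
```

Summing this over the mini-batch and adding the KL term gives the loss to minimize.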

Best regards

Thomas

Do you mean something like what is shown in the thread “Backward for negative log likelihood loss of MultivariateNormal (in distributions)”?

import torch

n = 5  # any positive integer
mu = torch.zeros(n)
C = torch.eye(n, n)
m = torch.distributions.MultivariateNormal(mu, covariance_matrix=C)
x = m.sample()  # should have shape (n,)
loss = -m.log_prob(x)  # should be a scalar
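That snippet builds a full `MultivariateNormal`. When the decoder only predicts per-element variances (a diagonal covariance), a hedged alternative is `Independent(Normal(mu, std), 1)`, which gives the same log-probability without building an n × n covariance matrix. A small check with arbitrary random `mu`, `logvar` and `x` (made-up values, float64 to keep the comparison tight):

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
batch, n = 4, 8
mu = torch.randn(batch, n, dtype=torch.float64)
logvar = torch.randn(batch, n, dtype=torch.float64)
std = torch.exp(0.5 * logvar)
x = torch.randn(batch, n, dtype=torch.float64)

# Full multivariate normal with a diagonal covariance matrix per sample.
mvn = MultivariateNormal(mu, covariance_matrix=torch.diag_embed(std ** 2))
# Same distribution as independent 1-D normals, summed over the last dimension.
diag = Independent(Normal(mu, std), 1)

nll_mvn = -mvn.log_prob(x)    # shape (batch,)
nll_diag = -diag.log_prob(x)  # shape (batch,)
```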

I had tried to avoid spelling out the log likelihood, but here it is. Note that there are two sets of mu and logvar:

import math

import torch
import torch.nn.functional as F

LOG_2_PI = math.log(2 * math.pi)


def decode(self, z):
    h = F.relu(self.fc3(z))
    mu = self.fc4(h)      # mean of p(x|z)
    logvar = self.fc5(h)  # log-variance of p(x|z)
    return mu, logvar

def forward(self, x):
    mu_latent, logvar_latent = self.encode(x.view(-1, 784))
    z = self.reparameterize(mu_latent, logvar_latent)
    mu_x, logvar_x = self.decode(z)
    return mu_latent, logvar_latent, mu_x, logvar_x

def loss_function(x, mu_x, logvar_x, mu_latent, logvar_latent):
    x = x.view(-1, 784)
    # neg log likelihood of x under a normal with independent elements,
    # summed over all elements and the batch (like the KLD below)
    loss_rec = 0.5 * torch.sum(LOG_2_PI + logvar_x + (x - mu_x)**2 / torch.exp(logvar_x))
    KLD = -0.5 * torch.sum(1 + logvar_latent - mu_latent.pow(2) - logvar_latent.exp())
    return loss_rec + KLD

I hope this approximately makes sense.

Best regards

Thomas


This makes a lot of sense to me now. Thanks for sharing this code here, @tom. I will give it a try.

In some cases the total loss becomes negative when the reconstruction term is the negative log-likelihood. How should this be interpreted?

Probably a bug in my formula… Can you give the input of the loss function yielding the neg loss, please?

@tom This is the loss function code that is equivalent to https://github.com/y0ast/Variational-Autoencoder/blob/master/VAE.py#L118.

import numpy as np
import torch


def VAELoss(
    outputs=None,
    targets=None,
    mus_latent=None,
    logvars_latent=None,
    mus_decoder=None,
    logvars_decoder=None,
    annealing=1.0,
    multivariate=None,
    latent=None,
    input_dimension=None,
):
    """Variational Autoencoder loss function
    Parameters
    ----------
    outputs : tensor
        Outputs of the model.
    targets : tensor
        Expected value of outputs.
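    mus_latent : tensor
        Mean values of the latent (encoder) distribution.
    logvars_latent : tensor
        Logarithm of the variance of the latent (encoder) distribution.
    mus_decoder : tensor
        Mean values of the decoder distribution (used when multivariate
        is True).
    logvars_decoder : tensor
        Logarithm of the variance of the decoder distribution (used when
        multivariate is True).
    multivariate : bool
        If multivariate is set to True we treat the distribution as a
        multivariate Gaussian distribution otherwise we use Bernoulli.
    annealing : float
        Weight of the KL divergence term in the total loss.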
    latent : tensor, optional
        The latent space tensor.
    input_dimension : int, optional
        Input's dimension.
    Returns
    -------
    loss : tensor
        The value of the loss function.
    """

    loss = []

    dim = 1
    if multivariate:
        # NLL per sample: 0.5 * sum(log(2*pi) + logvar_x + (x - mu_x)**2 / exp(logvar_x))
        loss_rec = -torch.sum(
            (-0.5 * np.log(2.0 * np.pi))
            + (-0.5 * logvars_decoder)
            + ((-0.5 / torch.exp(logvars_decoder)) * (targets - mus_decoder) ** 2.0),
            dim=dim,
        )

    else:
        # per-sample sum so its shape matches kld when the terms are stacked
        loss_rec = torch.nn.functional.binary_cross_entropy(
            outputs, targets, reduction="none"
        ).sum(dim=dim)
        loss_rec *= input_dimension

    loss.append(loss_rec)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    kld = (
        -0.5
        * torch.sum(
            1 + logvars_latent - mus_latent.pow(2) - logvars_latent.exp(), dim=dim
        )
        * annealing
    )
    loss.append(kld)

    if latent is not None:
        activation_reg = torch.mean(torch.pow(latent, 2), dim=dim)
        loss.append(activation_reg)

    # Sum the loss terms for each sample, then take the mini-batch mean
    loss = torch.mean(torch.stack(loss).sum(dim=0))

    return loss

The model can be seen here: https://github.com/muammar/ml4chem/blob/master/ml4chem/models/autoencoders.py#L361

I printed logvars_decoder and mus_decoder when the loss became negative:

LOGVAR_DECODER
tensor([[-4.0889458656, -3.5250387192, -3.8208389282, -3.5385358334,
         -3.1981320381, -3.9037659168, -3.2332408428, -4.3280816078],
        [-4.3436975479, -3.7619643211, -4.0336480141, -3.6318600178,
         -3.4294614792, -4.0175385475, -3.2871749401, -4.5249586105],
        [-4.3395757675, -3.9168510437, -4.0671305656, -3.7074320316,
         -3.4387311935, -4.0541105270, -3.3110878468, -4.5024285316],
        [-4.1954979897, -3.6041107178, -3.9152162075, -3.5348732471,
         -3.3275499344, -3.9352111816, -3.2573788166, -4.4472432137],
        [-4.2599620819, -3.6669955254, -3.9934899807, -3.6058964729,
         -3.3935651779, -3.9977569580, -3.3050155640, -4.5190782547],
        [-4.3928508759, -3.9678094387, -4.0128798485, -3.6809661388,
         -3.4883437157, -3.9990940094, -3.3398213387, -4.5465860367],
        [-4.2483749390, -3.6474187374, -3.8801121712, -3.4824745655,
         -3.3359556198, -3.8939044476, -3.2095971107, -4.4221296310],
        [-4.1744241714, -3.6982641220, -3.9988830090, -3.6817216873,
         -3.3437509537, -4.0203957558, -3.3463952541, -4.4784379005],
        [-3.9882695675, -3.9037389755, -3.9318742752, -3.8031373024,
         -3.2512860298, -3.9384653568, -3.3627109528, -4.2665996552],
        [-4.0217909813, -3.9362194538, -3.9807562828, -3.5862305164,
         -3.1849975586, -3.8795177937, -3.2918174267, -4.1493740082],
        [-4.2809176445, -3.8154747486, -3.9485902786, -3.5474169254,
         -3.3101720810, -3.9628999233, -3.1462957859, -4.3412165642],
        [-3.9343552589, -3.7270662785, -3.8015532494, -3.7683753967,
         -3.1551008224, -3.9056501389, -3.3827788830, -4.2404427528],
        [-3.9076354504, -3.9564294815, -3.7914583683, -3.8463647366,
         -3.1535534859, -3.8839437962, -3.3489274979, -4.1176047325],
        [-3.4092268944, -3.6969037056, -3.3210926056, -3.0190851688,
         -2.5197367668, -3.4038391113, -2.6846501827, -3.1622867584],
        [-4.3182497025, -4.1536660194, -4.1363797188, -3.8312237263,
         -3.4353456497, -4.1218142509, -3.3725814819, -4.4551277161],
        [-3.8396859169, -3.7065541744, -3.7668292522, -3.8391516209,
         -3.1027283669, -3.8931672573, -3.4790892601, -4.2176709175],
        [-3.5017061234, -3.8813052177, -3.5182900429, -3.8348097801,
         -2.8906803131, -3.6051583290, -3.3228216171, -3.8034570217],
        [-3.4464774132, -3.8789443970, -3.3224365711, -3.0810019970,
         -2.5025038719, -3.5263288021, -2.5929825306, -3.0947396755],
        [-4.2273983955, -4.0144639015, -3.9377801418, -3.5958638191,
         -3.2787804604, -3.9937865734, -3.1052067280, -4.2226104736],
        [-3.3745789528, -3.4972953796, -3.3253111839, -3.6773264408,
         -2.6872353554, -3.5508143902, -3.4156613350, -3.7921562195],
        [-2.9271032810, -3.6714038849, -3.0703654289, -3.7321207523,
         -2.3481974602, -3.2294888496, -3.2276337147, -3.2842714787],
        [-3.0432860851, -3.6146023273, -3.0295176506, -2.7825140953,
         -2.1663827896, -3.1854391098, -2.2825937271, -2.5911545753],
        [-3.7111034393, -3.8651938438, -3.6059389114, -3.3985249996,
         -2.8822917938, -3.6716985703, -2.9334106445, -3.6565389633],
        [-2.9104354382, -3.4743690491, -3.0280134678, -3.5700728893,
         -2.3488564491, -3.1583023071, -3.1791927814, -3.3389244080],
        [-2.2210299969, -3.0240457058, -2.3275172710, -3.1043970585,
         -1.6053416729, -2.4352619648, -2.4883928299, -2.3523766994],
        [-2.7114429474, -3.4283807278, -2.6234874725, -2.4926660061,
         -1.8203262091, -2.7575194836, -2.1052958965, -2.1186287403],
        [-3.8372142315, -4.0330691338, -3.6443238258, -3.3687250614,
         -2.8806028366, -3.8182375431, -2.8009984493, -3.6179041862],
        [-2.9028546810, -3.5859816074, -2.9423685074, -3.5868747234,
         -2.2634677887, -3.1730742455, -3.1181771755, -3.2169914246],
        [-2.7047109604, -3.4103722572, -2.7948386669, -3.5056848526,
         -2.0805091858, -3.0460426807, -3.0462079048, -3.0641751289],
        [-1.8170139790, -3.0801727772, -1.9800815582, -2.1787593365,
         -1.1494998932, -1.8789745569, -1.3521165848, -1.1914081573],
        [-3.1940779686, -3.8410174847, -3.2339777946, -3.0565440655,
         -2.3353605270, -3.3992166519, -2.4079468250, -2.7703025341],
        [-2.5278239250, -3.2410411835, -2.4651215076, -3.1415879726,
         -1.8574959040, -2.6578209400, -2.5992963314, -2.6902410984],
        [-2.0230534077, -2.6885709763, -2.0396063328, -2.6587059498,
         -1.4356783628, -2.0917904377, -2.2107565403, -2.1953237057],
        [-2.7910606861, -3.6545739174, -2.8367061615, -2.7674615383,
         -1.9324414730, -2.9588494301, -2.1176555157, -2.1152365208],
        [-3.7650642395, -4.1399159431, -3.6829044819, -3.4644801617,
         -2.8268494606, -3.8292882442, -3.0056245327, -3.5794913769],
        [-2.4097881317, -3.1195113659, -2.4535489082, -3.1327230930,
         -1.7928742170, -2.5936071873, -2.6375653744, -2.6450741291],
        [-2.0059213638, -2.5362682343, -2.0935509205, -2.8857729435,
         -1.3728393316, -2.4956555367, -2.5161824226, -2.4067625999],
        [-2.9616751671, -3.5862176418, -2.9856770039, -2.7996230125,
         -2.1296710968, -3.1245079041, -2.2272682190, -2.5133960247],
        [-2.6331017017, -3.3296575546, -2.7890243530, -2.7522990704,
         -1.9952768087, -2.7798025608, -2.1912658215, -2.3751411438],
        [-1.9964694977, -2.6528348923, -2.0138716698, -2.6910498142,
         -1.3910135031, -2.1067337990, -2.2230579853, -2.1664700508]],
       grad_fn=<StackBackward>)

MU_DECODER
tensor([[ 7.4537128210e-01,  7.7623891830e-01,  7.6602262259e-01,
         -1.3993513584e-01,  8.1436938047e-01,  3.6314576864e-01,
          9.9604159594e-01,  7.9063363373e-02],
        [ 8.9660757780e-01,  8.9131146669e-01,  8.7661617994e-01,
         -3.6907706410e-02,  8.2370609045e-01,  5.3367781639e-01,
          9.7979170084e-01,  2.0927461982e-01],
        [ 8.5515135527e-01,  8.2580548525e-01,  8.0289411545e-01,
          4.1661325842e-02,  7.1721255779e-01,  5.4174715281e-01,
          8.7252455950e-01,  2.3034229875e-01],
        [ 8.6751633883e-01,  8.9080280066e-01,  8.8437432051e-01,
         -7.2503603995e-02,  8.5934364796e-01,  5.0289857388e-01,
          1.0022629499e+00,  1.8224236369e-01],
        [ 8.8764733076e-01,  9.1166216135e-01,  9.1891723871e-01,
         -7.7194452286e-02,  8.8277554512e-01,  5.2277541161e-01,
          1.0235692263e+00,  1.9542214274e-01],
        [ 8.9673584700e-01,  8.6170548201e-01,  8.2436698675e-01,
          1.4161440730e-01,  6.5297460556e-01,  6.3108932972e-01,
          7.9818397760e-01,  2.6648667455e-01],
        [ 8.8351935148e-01,  8.8325554132e-01,  8.5816204548e-01,
         -8.0196680501e-03,  7.9587674141e-01,  5.3751409054e-01,
          9.4600826502e-01,  2.0341506600e-01],
        [ 8.1397992373e-01,  8.2838648558e-01,  8.5030966997e-01,
         -8.8126033545e-02,  8.3583903313e-01,  4.5510044694e-01,
          9.9602252245e-01,  1.5417987108e-01],
        [ 6.1925458908e-01,  5.9385389090e-01,  5.2979224920e-01,
         -3.5912264138e-02,  4.9642848969e-01,  3.3578324318e-01,
          7.0547688007e-01,  7.8385539353e-02],
        [ 7.8108674288e-01,  6.6775715351e-01,  6.1285018921e-01,
          3.1544300914e-01,  3.6865007877e-01,  6.7531591654e-01,
          6.4135974646e-01,  3.5447835922e-01],
        [ 8.3815830946e-01,  7.8540831804e-01,  7.2268778086e-01,
          5.5315937847e-02,  6.7108374834e-01,  4.9458086491e-01,
          8.4645766020e-01,  2.3364454508e-01],
        [ 5.1406806707e-01,  5.2309536934e-01,  5.0807440281e-01,
         -1.9287040830e-01,  5.8621060848e-01,  1.5216127038e-01,
          8.3681327105e-01, -6.7267514765e-02],
        [ 4.1226357222e-01,  3.9164602757e-01,  3.0982518196e-01,
         -1.3932901621e-01,  3.6478865147e-01,  1.1988976598e-01,
          6.1140972376e-01, -7.3039494455e-02],
        [ 4.7460585833e-01,  4.3725985289e-01,  2.8978574276e-01,
          7.2090071440e-01, -2.4166262150e-01,  8.2806438208e-01,
          6.1434850097e-02,  6.0027962923e-01],
        [ 7.9462063313e-01,  7.3788356781e-01,  6.8965417147e-01,
          1.9477413595e-01,  5.3217446804e-01,  5.9116268158e-01,
          6.9491815567e-01,  2.9300740361e-01],
        [ 4.0884870291e-01,  4.2614352703e-01,  4.3263792992e-01,
         -2.7052247524e-01,  5.5374783278e-01,  2.5412589312e-02,
          8.5867840052e-01, -1.6064369678e-01],
        [ 1.1232741177e-01,  9.4285562634e-02, -2.8797790408e-02,
         -2.8413486481e-01,  1.1447374523e-01, -1.2736788392e-01,
          3.8964253664e-01, -2.7499958873e-01],
        [ 4.0609455109e-01,  4.1544246674e-01,  2.3380063474e-01,
          8.1382852793e-01, -3.6809879541e-01,  8.4226578474e-01,
         -9.0999066830e-02,  6.8012529612e-01],
        [ 7.5890117884e-01,  7.1640801430e-01,  6.4160078764e-01,
          2.2286985815e-01,  4.8637318611e-01,  5.6602346897e-01,
          6.3253593445e-01,  3.3189448714e-01],
        [ 1.8294855952e-02,  4.1917800903e-02,  1.9956126809e-02,
         -4.5387506485e-01,  3.2704287767e-01, -3.3966571093e-01,
          7.0831686258e-01, -4.3913036585e-01],
        [-3.8637369871e-01, -4.5166379213e-01, -4.6067780256e-01,
         -5.8770400286e-01, -3.9115294814e-02, -5.3764253855e-01,
          2.3831032217e-01, -6.1640590429e-01],
        [ 2.9319143295e-01,  2.8827565908e-01,  6.5197840333e-02,
          8.6398667097e-01, -5.9744858742e-01,  8.3846032619e-01,
         -2.4567192793e-01,  6.7352229357e-01],
        [ 5.2184075117e-01,  5.0640445948e-01,  3.5496354103e-01,
          3.6912393570e-01,  8.2235828042e-02,  5.9509968758e-01,
          2.9390001297e-01,  3.5063615441e-01],
        [-2.9110723734e-01, -3.1803846359e-01, -3.5329234600e-01,
         -5.5038774014e-01,  2.4381056428e-02, -5.0387620926e-01,
          3.2649433613e-01, -5.5301195383e-01],
        [-7.2178471088e-01, -8.0932444334e-01, -7.1661406755e-01,
         -6.4933156967e-01, -2.3751400411e-01, -5.6131917238e-01,
         -1.1282922328e-01, -7.3082917929e-01],
        [ 1.5074957907e-01,  9.7607210279e-02, -1.0853867233e-01,
          1.0877382755e+00, -1.0112128258e+00,  9.1293346882e-01,
         -5.6111729145e-01,  7.7494198084e-01],
        [ 5.4971641302e-01,  5.4886144400e-01,  4.1464430094e-01,
          5.7090681791e-01,  1.5562146902e-02,  7.1120524406e-01,
          1.9895754755e-01,  5.4621404409e-01],
        [-3.9366084337e-01, -4.4121766090e-01, -4.5996707678e-01,
         -5.5510425568e-01, -1.1685043573e-03, -5.3734993935e-01,
          2.4105130136e-01, -6.1151641607e-01],
        [-5.2276879549e-01, -5.4054433107e-01, -5.0406098366e-01,
         -6.1084896326e-01, -1.4487281442e-02, -6.0916495323e-01,
          2.0622767508e-01, -6.7569190264e-01],
        [-3.6969506741e-01, -4.5419234037e-01, -8.5087037086e-01,
          5.9729576111e-01, -1.3964147568e+00,  4.8319572210e-01,
         -1.0880144835e+00,  3.2983240485e-01],
        [ 2.5268208981e-01,  2.6167953014e-01,  2.8274446726e-02,
          7.4333733320e-01, -5.5177694559e-01,  7.4476432800e-01,
         -2.3858781159e-01,  5.9774476290e-01],
        [-5.3853386641e-01, -5.8221745491e-01, -5.9578448534e-01,
         -5.2591758966e-01, -1.6031323373e-01, -4.9402362108e-01,
         -3.4457042813e-02, -6.1920011044e-01],
        [-6.1508226395e-01, -7.0803725719e-01, -7.0008456707e-01,
         -5.6278920174e-01, -2.5570809841e-01, -5.3972440958e-01,
         -1.1528758705e-01, -6.3002502918e-01],
        [ 5.8707669377e-02,  1.8676459789e-02, -2.2664265335e-01,
          9.2681169510e-01, -1.0662351847e+00,  8.2683765888e-01,
         -5.8269971609e-01,  6.1325526237e-01],
        [ 4.6706598997e-01,  4.5936805010e-01,  3.1905025244e-01,
          7.0009756088e-01, -1.6421130300e-01,  8.0398738384e-01,
          7.9242959619e-02,  5.8904093504e-01],
        [-5.7732629776e-01, -6.3657850027e-01, -6.1326271296e-01,
         -5.9521275759e-01, -1.5290783346e-01, -5.5260998011e-01,
          1.1421605945e-02, -6.6122150421e-01],
        [-8.7108415365e-01, -7.0264959335e-01, -5.4622644186e-01,
         -7.2423720360e-01,  6.8168237805e-02, -7.8077685833e-01,
          1.3776995242e-01, -7.8278994560e-01],
        [ 2.2678442299e-01,  2.3853041232e-01, -1.2939646840e-02,
          7.8404253721e-01, -6.2871348858e-01,  7.6879101992e-01,
         -2.8913736343e-01,  5.9305030107e-01],
        [ 1.0347022116e-01,  1.4599119127e-01, -1.9325403869e-01,
          5.6087404490e-01, -6.3103687763e-01,  6.0161834955e-01,
         -3.6667895317e-01,  4.0997803211e-01],
        [-6.8918240070e-01, -7.5843018293e-01, -7.0153790712e-01,
         -6.0384953022e-01, -2.2374112904e-01, -5.8500319719e-01,
         -1.0391147435e-01, -6.8481385708e-01]], grad_fn=<StackBackward>)

The data I am using with this VAE was preprocessed with MinMaxScaler, as implemented in scikit-learn, to the range (-1, 1). The problem I see here is that logvars_decoder became large and negative. I will investigate this problem further. I suppose in principle I could “penalize” logvars_decoder to keep its magnitude small?

I checked the loss function formula again and it seems fine. The KL divergence is always non-negative, but the reconstruction term changes sign. I have no idea why that would be.

For very peaked normals, the log density can become > 0.
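To make that concrete for the values posted above: the per-element log density can be at most -0.5 * (log(2*pi) + logvar), reached when the target equals the mean, and this is positive once logvar drops below -log(2*pi) ≈ -1.84. Most of the printed logvars_decoder entries are below that (around -4 gives about +1.08 per element), so a close prediction makes each element subtract roughly 1 from the reconstruction loss. A small helper (the name `log_density_at_mean` is hypothetical) to evaluate it:

```python
import math

import torch


def log_density_at_mean(logvar):
    """Largest possible value of log N(x; mu, exp(logvar)), reached at x = mu."""
    return -0.5 * (math.log(2 * math.pi) + logvar)


# A few log-variances in the range printed above, plus the break-even point.
logvars = torch.tensor([-4.0, -3.0, -math.log(2 * math.pi), -1.0])
values = log_density_at_mean(logvars)
print(values)
```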

I have the same problem: the reconstruction loss becomes negative after a few epochs. I don’t have any idea how to fix it.

But are you sure it is unexpected?
Does the VAE learn?

I do seem to recall that the loss could become negative in GPs, too, so there isn’t an inherent reason why it should not here.

Best regards

Thomas

Thank you. I am just starting with VAEs and have never encountered a negative loss before. However, the VAE does learn!


Personally, I’d think it’s OK then. If your std is 0.1 (if the shaded region is one std) or 0.05 (if it’s two stds), the density peaks at ~4 or ~8, so one might expect a positive log likelihood.
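A quick check of those numbers: the peak of a normal density is 1/(std * sqrt(2*pi)), so its log is positive whenever std < 1/sqrt(2*pi) ≈ 0.4. The sketch below evaluates the density at the mean with `torch.distributions.Normal` for the two standard deviations mentioned:

```python
import torch
from torch.distributions import Normal

peaks = {}
for std in (0.1, 0.05):
    dist = Normal(torch.tensor(0.0), torch.tensor(std))
    peaks[std] = dist.log_prob(torch.tensor(0.0)).exp().item()  # density at the mean
    print(f"std={std}: peak density {peaks[std]:.2f}")
# std=0.1: peak density 3.99
# std=0.05: peak density 7.98
```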

Best regards

Thomas

That region is 4 std wide (from mean - 2 * std to mean + 2 * std).

In my use case, the VAE did learn. As you mentioned above, this might indicate that the normal distribution is very narrow/peaked: once the variance drops below 1/(2π) ≈ 0.16, the density at the mean exceeds 1, so its log is positive. The log of a small variance (logvar_x) is large and negative, so in the negative log-likelihood the 0.5 * logvar_x term dominates and pulls the reconstruction loss below zero. At this point, I think a negative reconstruction loss is fine mathematically speaking, but I might still be wrong.