Multivariate Gaussian Variational Autoencoder (the decoder part)

@tom This is the loss function code that is equivalent to https://github.com/y0ast/Variational-Autoencoder/blob/master/VAE.py#L118.

def VAELoss(
    outputs=None,
    targets=None,
    mus_latent=None,
    logvars_latent=None,
    mus_decoder=None,
    logvars_decoder=None,
    annealing=None,
    multivariate=None,
    latent=None,
    input_dimension=None,
):
    """Variational Autoencoder loss function.

    Computes the negative ELBO: a reconstruction negative log-likelihood
    plus an (annealed) KL divergence between q(z|x) and N(0, I). Optionally
    an L2 penalty on the latent activations is added. Every term is computed
    per sample, the terms are summed, and the result is averaged over the
    mini-batch.

    Parameters
    ----------
    outputs : tensor
        Outputs of the model. Only used by the Bernoulli reconstruction
        term (``multivariate`` falsy); values must lie in [0, 1].
    targets : tensor
        Expected value of outputs.
    mus_latent : tensor
        Mean values of the latent distribution q(z|x).
    logvars_latent : tensor
        Logarithm of the variance of q(z|x).
    mus_decoder : tensor
        Mean values of the Gaussian decoder p(x|z). Only used when
        ``multivariate`` is True.
    logvars_decoder : tensor
        Logarithm of the variance of the Gaussian decoder p(x|z). Only used
        when ``multivariate`` is True.
    annealing : float, optional
        Weight of the KL divergence term. ``None`` is treated as 1.0
        (no annealing).
    multivariate : bool
        If True the decoder is treated as a diagonal multivariate Gaussian
        distribution, otherwise a Bernoulli likelihood is used.
    latent : tensor, optional
        The latent space tensor. If given, the per-sample mean of its
        squared entries is added as an activation regularizer.
    input_dimension : int, optional
        Input's dimension. For the Bernoulli term the per-sample mean binary
        cross-entropy is multiplied by this value; when omitted the
        per-sample sum is used (the two agree when ``input_dimension``
        equals the number of features).

    Returns
    -------
    loss : tensor
        Scalar: the mini-batch mean of the summed loss terms.

    Notes
    -----
    With a Gaussian decoder the reconstruction term is a negative
    log-*density*, which is negative whenever the density exceeds 1 (for
    example when ``logvars_decoder`` is small). A negative total loss is
    therefore not by itself a bug.
    """
    # Per-sample reductions run over axis 1; assumes (batch, features)
    # inputs — TODO confirm for other layouts.
    dim = 1

    if annealing is None:
        annealing = 1.0

    if multivariate:
        # Diagonal Gaussian negative log-likelihood, summed over features:
        # 0.5 * log(2*pi) + 0.5 * logvar_x + (x - mu_x)**2 / (2 * exp(logvar_x))
        # exp(-logvar) avoids dividing by a possibly tiny exp(logvar).
        loss_rec = torch.sum(
            0.5 * np.log(2.0 * np.pi)
            + 0.5 * logvars_decoder
            + 0.5 * torch.exp(-logvars_decoder) * (targets - mus_decoder) ** 2,
            dim=dim,
        )
    else:
        # reduction="none" keeps one value per element so the term can be
        # reduced per sample to shape (batch,), matching the KL term. A
        # "sum" reduction would yield a scalar that cannot be stacked with
        # the other per-sample terms below.
        bce = torch.nn.functional.binary_cross_entropy(
            outputs, targets, reduction="none"
        )
        if input_dimension is None:
            loss_rec = torch.sum(bce, dim=dim)
        else:
            loss_rec = torch.mean(bce, dim=dim) * input_dimension

    # KL(q(z|x) || N(0, I)) per sample. See Appendix B of
    # Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    # (https://arxiv.org/abs/1312.6114):
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(
        1 + logvars_latent - mus_latent.pow(2) - logvars_latent.exp(), dim=dim
    )

    terms = [loss_rec, annealing * kld]

    if latent is not None:
        # Activation regularizer: per-sample mean of squared latent values.
        terms.append(torch.mean(torch.pow(latent, 2), dim=dim))

    # Sum the terms per sample, then take the mini-batch mean. (Averaging
    # over the stacked terms instead would divide the loss by the number
    # of terms.)
    return torch.stack(terms).sum(dim=0).mean()

The model can be seen here: https://github.com/muammar/ml4chem/blob/master/ml4chem/models/autoencoders.py#L361

I printed logvars_decoder and mus_decoder at the point where the loss became negative:

LOGVAR_DECODER
tensor([[-4.0889458656, -3.5250387192, -3.8208389282, -3.5385358334,
         -3.1981320381, -3.9037659168, -3.2332408428, -4.3280816078],
        [-4.3436975479, -3.7619643211, -4.0336480141, -3.6318600178,
         -3.4294614792, -4.0175385475, -3.2871749401, -4.5249586105],
        [-4.3395757675, -3.9168510437, -4.0671305656, -3.7074320316,
         -3.4387311935, -4.0541105270, -3.3110878468, -4.5024285316],
        [-4.1954979897, -3.6041107178, -3.9152162075, -3.5348732471,
         -3.3275499344, -3.9352111816, -3.2573788166, -4.4472432137],
        [-4.2599620819, -3.6669955254, -3.9934899807, -3.6058964729,
         -3.3935651779, -3.9977569580, -3.3050155640, -4.5190782547],
        [-4.3928508759, -3.9678094387, -4.0128798485, -3.6809661388,
         -3.4883437157, -3.9990940094, -3.3398213387, -4.5465860367],
        [-4.2483749390, -3.6474187374, -3.8801121712, -3.4824745655,
         -3.3359556198, -3.8939044476, -3.2095971107, -4.4221296310],
        [-4.1744241714, -3.6982641220, -3.9988830090, -3.6817216873,
         -3.3437509537, -4.0203957558, -3.3463952541, -4.4784379005],
        [-3.9882695675, -3.9037389755, -3.9318742752, -3.8031373024,
         -3.2512860298, -3.9384653568, -3.3627109528, -4.2665996552],
        [-4.0217909813, -3.9362194538, -3.9807562828, -3.5862305164,
         -3.1849975586, -3.8795177937, -3.2918174267, -4.1493740082],
        [-4.2809176445, -3.8154747486, -3.9485902786, -3.5474169254,
         -3.3101720810, -3.9628999233, -3.1462957859, -4.3412165642],
        [-3.9343552589, -3.7270662785, -3.8015532494, -3.7683753967,
         -3.1551008224, -3.9056501389, -3.3827788830, -4.2404427528],
        [-3.9076354504, -3.9564294815, -3.7914583683, -3.8463647366,
         -3.1535534859, -3.8839437962, -3.3489274979, -4.1176047325],
        [-3.4092268944, -3.6969037056, -3.3210926056, -3.0190851688,
         -2.5197367668, -3.4038391113, -2.6846501827, -3.1622867584],
        [-4.3182497025, -4.1536660194, -4.1363797188, -3.8312237263,
         -3.4353456497, -4.1218142509, -3.3725814819, -4.4551277161],
        [-3.8396859169, -3.7065541744, -3.7668292522, -3.8391516209,
         -3.1027283669, -3.8931672573, -3.4790892601, -4.2176709175],
        [-3.5017061234, -3.8813052177, -3.5182900429, -3.8348097801,
         -2.8906803131, -3.6051583290, -3.3228216171, -3.8034570217],
        [-3.4464774132, -3.8789443970, -3.3224365711, -3.0810019970,
         -2.5025038719, -3.5263288021, -2.5929825306, -3.0947396755],
        [-4.2273983955, -4.0144639015, -3.9377801418, -3.5958638191,
         -3.2787804604, -3.9937865734, -3.1052067280, -4.2226104736],
        [-3.3745789528, -3.4972953796, -3.3253111839, -3.6773264408,
         -2.6872353554, -3.5508143902, -3.4156613350, -3.7921562195],
        [-2.9271032810, -3.6714038849, -3.0703654289, -3.7321207523,
         -2.3481974602, -3.2294888496, -3.2276337147, -3.2842714787],
        [-3.0432860851, -3.6146023273, -3.0295176506, -2.7825140953,
         -2.1663827896, -3.1854391098, -2.2825937271, -2.5911545753],
        [-3.7111034393, -3.8651938438, -3.6059389114, -3.3985249996,
         -2.8822917938, -3.6716985703, -2.9334106445, -3.6565389633],
        [-2.9104354382, -3.4743690491, -3.0280134678, -3.5700728893,
         -2.3488564491, -3.1583023071, -3.1791927814, -3.3389244080],
        [-2.2210299969, -3.0240457058, -2.3275172710, -3.1043970585,
         -1.6053416729, -2.4352619648, -2.4883928299, -2.3523766994],
        [-2.7114429474, -3.4283807278, -2.6234874725, -2.4926660061,
         -1.8203262091, -2.7575194836, -2.1052958965, -2.1186287403],
        [-3.8372142315, -4.0330691338, -3.6443238258, -3.3687250614,
         -2.8806028366, -3.8182375431, -2.8009984493, -3.6179041862],
        [-2.9028546810, -3.5859816074, -2.9423685074, -3.5868747234,
         -2.2634677887, -3.1730742455, -3.1181771755, -3.2169914246],
        [-2.7047109604, -3.4103722572, -2.7948386669, -3.5056848526,
         -2.0805091858, -3.0460426807, -3.0462079048, -3.0641751289],
        [-1.8170139790, -3.0801727772, -1.9800815582, -2.1787593365,
         -1.1494998932, -1.8789745569, -1.3521165848, -1.1914081573],
        [-3.1940779686, -3.8410174847, -3.2339777946, -3.0565440655,
         -2.3353605270, -3.3992166519, -2.4079468250, -2.7703025341],
        [-2.5278239250, -3.2410411835, -2.4651215076, -3.1415879726,
         -1.8574959040, -2.6578209400, -2.5992963314, -2.6902410984],
        [-2.0230534077, -2.6885709763, -2.0396063328, -2.6587059498,
         -1.4356783628, -2.0917904377, -2.2107565403, -2.1953237057],
        [-2.7910606861, -3.6545739174, -2.8367061615, -2.7674615383,
         -1.9324414730, -2.9588494301, -2.1176555157, -2.1152365208],
        [-3.7650642395, -4.1399159431, -3.6829044819, -3.4644801617,
         -2.8268494606, -3.8292882442, -3.0056245327, -3.5794913769],
        [-2.4097881317, -3.1195113659, -2.4535489082, -3.1327230930,
         -1.7928742170, -2.5936071873, -2.6375653744, -2.6450741291],
        [-2.0059213638, -2.5362682343, -2.0935509205, -2.8857729435,
         -1.3728393316, -2.4956555367, -2.5161824226, -2.4067625999],
        [-2.9616751671, -3.5862176418, -2.9856770039, -2.7996230125,
         -2.1296710968, -3.1245079041, -2.2272682190, -2.5133960247],
        [-2.6331017017, -3.3296575546, -2.7890243530, -2.7522990704,
         -1.9952768087, -2.7798025608, -2.1912658215, -2.3751411438],
        [-1.9964694977, -2.6528348923, -2.0138716698, -2.6910498142,
         -1.3910135031, -2.1067337990, -2.2230579853, -2.1664700508]],
       grad_fn=<StackBackward>)

MU_DECODER
tensor([[ 7.4537128210e-01,  7.7623891830e-01,  7.6602262259e-01,
         -1.3993513584e-01,  8.1436938047e-01,  3.6314576864e-01,
          9.9604159594e-01,  7.9063363373e-02],
        [ 8.9660757780e-01,  8.9131146669e-01,  8.7661617994e-01,
         -3.6907706410e-02,  8.2370609045e-01,  5.3367781639e-01,
          9.7979170084e-01,  2.0927461982e-01],
        [ 8.5515135527e-01,  8.2580548525e-01,  8.0289411545e-01,
          4.1661325842e-02,  7.1721255779e-01,  5.4174715281e-01,
          8.7252455950e-01,  2.3034229875e-01],
        [ 8.6751633883e-01,  8.9080280066e-01,  8.8437432051e-01,
         -7.2503603995e-02,  8.5934364796e-01,  5.0289857388e-01,
          1.0022629499e+00,  1.8224236369e-01],
        [ 8.8764733076e-01,  9.1166216135e-01,  9.1891723871e-01,
         -7.7194452286e-02,  8.8277554512e-01,  5.2277541161e-01,
          1.0235692263e+00,  1.9542214274e-01],
        [ 8.9673584700e-01,  8.6170548201e-01,  8.2436698675e-01,
          1.4161440730e-01,  6.5297460556e-01,  6.3108932972e-01,
          7.9818397760e-01,  2.6648667455e-01],
        [ 8.8351935148e-01,  8.8325554132e-01,  8.5816204548e-01,
         -8.0196680501e-03,  7.9587674141e-01,  5.3751409054e-01,
          9.4600826502e-01,  2.0341506600e-01],
        [ 8.1397992373e-01,  8.2838648558e-01,  8.5030966997e-01,
         -8.8126033545e-02,  8.3583903313e-01,  4.5510044694e-01,
          9.9602252245e-01,  1.5417987108e-01],
        [ 6.1925458908e-01,  5.9385389090e-01,  5.2979224920e-01,
         -3.5912264138e-02,  4.9642848969e-01,  3.3578324318e-01,
          7.0547688007e-01,  7.8385539353e-02],
        [ 7.8108674288e-01,  6.6775715351e-01,  6.1285018921e-01,
          3.1544300914e-01,  3.6865007877e-01,  6.7531591654e-01,
          6.4135974646e-01,  3.5447835922e-01],
        [ 8.3815830946e-01,  7.8540831804e-01,  7.2268778086e-01,
          5.5315937847e-02,  6.7108374834e-01,  4.9458086491e-01,
          8.4645766020e-01,  2.3364454508e-01],
        [ 5.1406806707e-01,  5.2309536934e-01,  5.0807440281e-01,
         -1.9287040830e-01,  5.8621060848e-01,  1.5216127038e-01,
          8.3681327105e-01, -6.7267514765e-02],
        [ 4.1226357222e-01,  3.9164602757e-01,  3.0982518196e-01,
         -1.3932901621e-01,  3.6478865147e-01,  1.1988976598e-01,
          6.1140972376e-01, -7.3039494455e-02],
        [ 4.7460585833e-01,  4.3725985289e-01,  2.8978574276e-01,
          7.2090071440e-01, -2.4166262150e-01,  8.2806438208e-01,
          6.1434850097e-02,  6.0027962923e-01],
        [ 7.9462063313e-01,  7.3788356781e-01,  6.8965417147e-01,
          1.9477413595e-01,  5.3217446804e-01,  5.9116268158e-01,
          6.9491815567e-01,  2.9300740361e-01],
        [ 4.0884870291e-01,  4.2614352703e-01,  4.3263792992e-01,
         -2.7052247524e-01,  5.5374783278e-01,  2.5412589312e-02,
          8.5867840052e-01, -1.6064369678e-01],
        [ 1.1232741177e-01,  9.4285562634e-02, -2.8797790408e-02,
         -2.8413486481e-01,  1.1447374523e-01, -1.2736788392e-01,
          3.8964253664e-01, -2.7499958873e-01],
        [ 4.0609455109e-01,  4.1544246674e-01,  2.3380063474e-01,
          8.1382852793e-01, -3.6809879541e-01,  8.4226578474e-01,
         -9.0999066830e-02,  6.8012529612e-01],
        [ 7.5890117884e-01,  7.1640801430e-01,  6.4160078764e-01,
          2.2286985815e-01,  4.8637318611e-01,  5.6602346897e-01,
          6.3253593445e-01,  3.3189448714e-01],
        [ 1.8294855952e-02,  4.1917800903e-02,  1.9956126809e-02,
         -4.5387506485e-01,  3.2704287767e-01, -3.3966571093e-01,
          7.0831686258e-01, -4.3913036585e-01],
        [-3.8637369871e-01, -4.5166379213e-01, -4.6067780256e-01,
         -5.8770400286e-01, -3.9115294814e-02, -5.3764253855e-01,
          2.3831032217e-01, -6.1640590429e-01],
        [ 2.9319143295e-01,  2.8827565908e-01,  6.5197840333e-02,
          8.6398667097e-01, -5.9744858742e-01,  8.3846032619e-01,
         -2.4567192793e-01,  6.7352229357e-01],
        [ 5.2184075117e-01,  5.0640445948e-01,  3.5496354103e-01,
          3.6912393570e-01,  8.2235828042e-02,  5.9509968758e-01,
          2.9390001297e-01,  3.5063615441e-01],
        [-2.9110723734e-01, -3.1803846359e-01, -3.5329234600e-01,
         -5.5038774014e-01,  2.4381056428e-02, -5.0387620926e-01,
          3.2649433613e-01, -5.5301195383e-01],
        [-7.2178471088e-01, -8.0932444334e-01, -7.1661406755e-01,
         -6.4933156967e-01, -2.3751400411e-01, -5.6131917238e-01,
         -1.1282922328e-01, -7.3082917929e-01],
        [ 1.5074957907e-01,  9.7607210279e-02, -1.0853867233e-01,
          1.0877382755e+00, -1.0112128258e+00,  9.1293346882e-01,
         -5.6111729145e-01,  7.7494198084e-01],
        [ 5.4971641302e-01,  5.4886144400e-01,  4.1464430094e-01,
          5.7090681791e-01,  1.5562146902e-02,  7.1120524406e-01,
          1.9895754755e-01,  5.4621404409e-01],
        [-3.9366084337e-01, -4.4121766090e-01, -4.5996707678e-01,
         -5.5510425568e-01, -1.1685043573e-03, -5.3734993935e-01,
          2.4105130136e-01, -6.1151641607e-01],
        [-5.2276879549e-01, -5.4054433107e-01, -5.0406098366e-01,
         -6.1084896326e-01, -1.4487281442e-02, -6.0916495323e-01,
          2.0622767508e-01, -6.7569190264e-01],
        [-3.6969506741e-01, -4.5419234037e-01, -8.5087037086e-01,
          5.9729576111e-01, -1.3964147568e+00,  4.8319572210e-01,
         -1.0880144835e+00,  3.2983240485e-01],
        [ 2.5268208981e-01,  2.6167953014e-01,  2.8274446726e-02,
          7.4333733320e-01, -5.5177694559e-01,  7.4476432800e-01,
         -2.3858781159e-01,  5.9774476290e-01],
        [-5.3853386641e-01, -5.8221745491e-01, -5.9578448534e-01,
         -5.2591758966e-01, -1.6031323373e-01, -4.9402362108e-01,
         -3.4457042813e-02, -6.1920011044e-01],
        [-6.1508226395e-01, -7.0803725719e-01, -7.0008456707e-01,
         -5.6278920174e-01, -2.5570809841e-01, -5.3972440958e-01,
         -1.1528758705e-01, -6.3002502918e-01],
        [ 5.8707669377e-02,  1.8676459789e-02, -2.2664265335e-01,
          9.2681169510e-01, -1.0662351847e+00,  8.2683765888e-01,
         -5.8269971609e-01,  6.1325526237e-01],
        [ 4.6706598997e-01,  4.5936805010e-01,  3.1905025244e-01,
          7.0009756088e-01, -1.6421130300e-01,  8.0398738384e-01,
          7.9242959619e-02,  5.8904093504e-01],
        [-5.7732629776e-01, -6.3657850027e-01, -6.1326271296e-01,
         -5.9521275759e-01, -1.5290783346e-01, -5.5260998011e-01,
          1.1421605945e-02, -6.6122150421e-01],
        [-8.7108415365e-01, -7.0264959335e-01, -5.4622644186e-01,
         -7.2423720360e-01,  6.8168237805e-02, -7.8077685833e-01,
          1.3776995242e-01, -7.8278994560e-01],
        [ 2.2678442299e-01,  2.3853041232e-01, -1.2939646840e-02,
          7.8404253721e-01, -6.2871348858e-01,  7.6879101992e-01,
         -2.8913736343e-01,  5.9305030107e-01],
        [ 1.0347022116e-01,  1.4599119127e-01, -1.9325403869e-01,
          5.6087404490e-01, -6.3103687763e-01,  6.0161834955e-01,
         -3.6667895317e-01,  4.0997803211e-01],
        [-6.8918240070e-01, -7.5843018293e-01, -7.0153790712e-01,
         -6.0384953022e-01, -2.2374112904e-01, -5.8500319719e-01,
         -1.0391147435e-01, -6.8481385708e-01]], grad_fn=<StackBackward>)

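The data I am using with this VAE was preprocessed with scikit-learn's MinMaxScaler to the range (-1, 1). The problem I see here is that logvars_decoder became large and negative (roughly -1 to -4.5 above). Each feature contributes -0.5 * logvar_x to the Gaussian log-likelihood, so with logvar_x around -4 that is about +2 per feature. This pushes the log-likelihood above zero and hence the reconstruction loss below zero. I will investigate this problem further. I suppose that, in principle, I could “penalize” logvars_decoder so that it does not become too negative?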