@tom This is the loss function code that is equivalent to https://github.com/y0ast/Variational-Autoencoder/blob/master/VAE.py#L118.
def VAELoss(
    outputs=None,
    targets=None,
    mus_latent=None,
    logvars_latent=None,
    mus_decoder=None,
    logvars_decoder=None,
    annealing=None,
    multivariate=None,
    latent=None,
    input_dimension=None,
):
    """Variational Autoencoder loss function

    Builds a per-sample reconstruction term, a per-sample KL-divergence term
    and, optionally, a per-sample activation regularizer (each reduced over
    ``dim=1``), stacks them and returns the mean over all stacked entries.

    Parameters
    ----------
    outputs : tensor
        Outputs of the model. Only used when ``multivariate`` is falsy.
    targets : tensor
        Expected value of outputs.
    mus_latent : tensor
        Mean values of the latent distribution.
    logvars_latent : tensor
        Logarithm of the variance of the latent distribution.
    mus_decoder : tensor
        Mean values predicted by the decoder. Only used when ``multivariate``
        is truthy.
    logvars_decoder : tensor
        Logarithm of the variance predicted by the decoder. Only used when
        ``multivariate`` is truthy.
    annealing : float, optional
        Weight multiplying the KL-divergence term. ``None`` is treated as 1.0
        (no annealing).
    multivariate : bool
        If truthy, the reconstruction term is the negative log-likelihood of
        a Gaussian with diagonal covariance; otherwise binary cross entropy
        (Bernoulli) is used.
    latent : tensor, optional
        The latent space tensor. When given, the mean of its squared entries
        is added as an extra loss term.
    input_dimension : int, optional
        Input's dimension. When given, the Bernoulli reconstruction term is
        multiplied by it; ``None`` leaves the term unscaled.

    Returns
    -------
    loss : tensor
        The value of the loss function (a 0-d tensor).

    Notes
    -----
    The returned value is the mean over the stacked terms *and* the batch,
    i.e. the batch-averaged sum of the terms divided by the number of terms.

    With ``multivariate`` set, the reconstruction term is a negative
    log-density, so it (and therefore the loss) can be negative when
    ``logvars_decoder`` is strongly negative and the residuals are small.
    """
    dim = 1  # axis over which every per-sample reduction is taken
    terms = []

    if multivariate:
        # Negative Gaussian log-likelihood, summed over `dim`:
        # 0.5 * (log(2*pi) + logvar_x + (x - mu_x)**2 / exp(logvar_x))
        loss_rec = -torch.sum(
            (-0.5 * np.log(2.0 * np.pi))
            + (-0.5 * logvars_decoder)
            + ((-0.5 / torch.exp(logvars_decoder)) * (targets - mus_decoder) ** 2.0),
            dim=dim,
        )
    else:
        # Reduce per sample (not over the whole batch) so that this term has
        # the same shape as the KL term and can be stacked with it below.
        elementwise_bce = torch.nn.functional.binary_cross_entropy(
            outputs, targets, reduction="none"
        )
        loss_rec = torch.sum(elementwise_bce, dim=dim)
        if input_dimension is not None:
            loss_rec = loss_rec * input_dimension
    terms.append(loss_rec)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld_weight = 1.0 if annealing is None else annealing
    kld = (
        -0.5
        * torch.sum(
            1 + logvars_latent - mus_latent.pow(2) - logvars_latent.exp(), dim=dim
        )
        * kld_weight
    )
    terms.append(kld)

    if latent is not None:
        activation_reg = torch.mean(torch.pow(latent, 2), dim=dim)
        terms.append(activation_reg)

    # Mean over the stacked terms and the mini-batch
    return torch.mean(torch.stack(terms))
The model can be seen here: https://github.com/muammar/ml4chem/blob/master/ml4chem/models/autoencoders.py#L361
What I did was print `logvars_decoder` and `mus_decoder` when the loss becomes negative:
LOGVAR_DECODER
tensor([[-4.0889458656, -3.5250387192, -3.8208389282, -3.5385358334,
-3.1981320381, -3.9037659168, -3.2332408428, -4.3280816078],
[-4.3436975479, -3.7619643211, -4.0336480141, -3.6318600178,
-3.4294614792, -4.0175385475, -3.2871749401, -4.5249586105],
[-4.3395757675, -3.9168510437, -4.0671305656, -3.7074320316,
-3.4387311935, -4.0541105270, -3.3110878468, -4.5024285316],
[-4.1954979897, -3.6041107178, -3.9152162075, -3.5348732471,
-3.3275499344, -3.9352111816, -3.2573788166, -4.4472432137],
[-4.2599620819, -3.6669955254, -3.9934899807, -3.6058964729,
-3.3935651779, -3.9977569580, -3.3050155640, -4.5190782547],
[-4.3928508759, -3.9678094387, -4.0128798485, -3.6809661388,
-3.4883437157, -3.9990940094, -3.3398213387, -4.5465860367],
[-4.2483749390, -3.6474187374, -3.8801121712, -3.4824745655,
-3.3359556198, -3.8939044476, -3.2095971107, -4.4221296310],
[-4.1744241714, -3.6982641220, -3.9988830090, -3.6817216873,
-3.3437509537, -4.0203957558, -3.3463952541, -4.4784379005],
[-3.9882695675, -3.9037389755, -3.9318742752, -3.8031373024,
-3.2512860298, -3.9384653568, -3.3627109528, -4.2665996552],
[-4.0217909813, -3.9362194538, -3.9807562828, -3.5862305164,
-3.1849975586, -3.8795177937, -3.2918174267, -4.1493740082],
[-4.2809176445, -3.8154747486, -3.9485902786, -3.5474169254,
-3.3101720810, -3.9628999233, -3.1462957859, -4.3412165642],
[-3.9343552589, -3.7270662785, -3.8015532494, -3.7683753967,
-3.1551008224, -3.9056501389, -3.3827788830, -4.2404427528],
[-3.9076354504, -3.9564294815, -3.7914583683, -3.8463647366,
-3.1535534859, -3.8839437962, -3.3489274979, -4.1176047325],
[-3.4092268944, -3.6969037056, -3.3210926056, -3.0190851688,
-2.5197367668, -3.4038391113, -2.6846501827, -3.1622867584],
[-4.3182497025, -4.1536660194, -4.1363797188, -3.8312237263,
-3.4353456497, -4.1218142509, -3.3725814819, -4.4551277161],
[-3.8396859169, -3.7065541744, -3.7668292522, -3.8391516209,
-3.1027283669, -3.8931672573, -3.4790892601, -4.2176709175],
[-3.5017061234, -3.8813052177, -3.5182900429, -3.8348097801,
-2.8906803131, -3.6051583290, -3.3228216171, -3.8034570217],
[-3.4464774132, -3.8789443970, -3.3224365711, -3.0810019970,
-2.5025038719, -3.5263288021, -2.5929825306, -3.0947396755],
[-4.2273983955, -4.0144639015, -3.9377801418, -3.5958638191,
-3.2787804604, -3.9937865734, -3.1052067280, -4.2226104736],
[-3.3745789528, -3.4972953796, -3.3253111839, -3.6773264408,
-2.6872353554, -3.5508143902, -3.4156613350, -3.7921562195],
[-2.9271032810, -3.6714038849, -3.0703654289, -3.7321207523,
-2.3481974602, -3.2294888496, -3.2276337147, -3.2842714787],
[-3.0432860851, -3.6146023273, -3.0295176506, -2.7825140953,
-2.1663827896, -3.1854391098, -2.2825937271, -2.5911545753],
[-3.7111034393, -3.8651938438, -3.6059389114, -3.3985249996,
-2.8822917938, -3.6716985703, -2.9334106445, -3.6565389633],
[-2.9104354382, -3.4743690491, -3.0280134678, -3.5700728893,
-2.3488564491, -3.1583023071, -3.1791927814, -3.3389244080],
[-2.2210299969, -3.0240457058, -2.3275172710, -3.1043970585,
-1.6053416729, -2.4352619648, -2.4883928299, -2.3523766994],
[-2.7114429474, -3.4283807278, -2.6234874725, -2.4926660061,
-1.8203262091, -2.7575194836, -2.1052958965, -2.1186287403],
[-3.8372142315, -4.0330691338, -3.6443238258, -3.3687250614,
-2.8806028366, -3.8182375431, -2.8009984493, -3.6179041862],
[-2.9028546810, -3.5859816074, -2.9423685074, -3.5868747234,
-2.2634677887, -3.1730742455, -3.1181771755, -3.2169914246],
[-2.7047109604, -3.4103722572, -2.7948386669, -3.5056848526,
-2.0805091858, -3.0460426807, -3.0462079048, -3.0641751289],
[-1.8170139790, -3.0801727772, -1.9800815582, -2.1787593365,
-1.1494998932, -1.8789745569, -1.3521165848, -1.1914081573],
[-3.1940779686, -3.8410174847, -3.2339777946, -3.0565440655,
-2.3353605270, -3.3992166519, -2.4079468250, -2.7703025341],
[-2.5278239250, -3.2410411835, -2.4651215076, -3.1415879726,
-1.8574959040, -2.6578209400, -2.5992963314, -2.6902410984],
[-2.0230534077, -2.6885709763, -2.0396063328, -2.6587059498,
-1.4356783628, -2.0917904377, -2.2107565403, -2.1953237057],
[-2.7910606861, -3.6545739174, -2.8367061615, -2.7674615383,
-1.9324414730, -2.9588494301, -2.1176555157, -2.1152365208],
[-3.7650642395, -4.1399159431, -3.6829044819, -3.4644801617,
-2.8268494606, -3.8292882442, -3.0056245327, -3.5794913769],
[-2.4097881317, -3.1195113659, -2.4535489082, -3.1327230930,
-1.7928742170, -2.5936071873, -2.6375653744, -2.6450741291],
[-2.0059213638, -2.5362682343, -2.0935509205, -2.8857729435,
-1.3728393316, -2.4956555367, -2.5161824226, -2.4067625999],
[-2.9616751671, -3.5862176418, -2.9856770039, -2.7996230125,
-2.1296710968, -3.1245079041, -2.2272682190, -2.5133960247],
[-2.6331017017, -3.3296575546, -2.7890243530, -2.7522990704,
-1.9952768087, -2.7798025608, -2.1912658215, -2.3751411438],
[-1.9964694977, -2.6528348923, -2.0138716698, -2.6910498142,
-1.3910135031, -2.1067337990, -2.2230579853, -2.1664700508]],
grad_fn=<StackBackward>)
MU_DECODER
tensor([[ 7.4537128210e-01, 7.7623891830e-01, 7.6602262259e-01,
-1.3993513584e-01, 8.1436938047e-01, 3.6314576864e-01,
9.9604159594e-01, 7.9063363373e-02],
[ 8.9660757780e-01, 8.9131146669e-01, 8.7661617994e-01,
-3.6907706410e-02, 8.2370609045e-01, 5.3367781639e-01,
9.7979170084e-01, 2.0927461982e-01],
[ 8.5515135527e-01, 8.2580548525e-01, 8.0289411545e-01,
4.1661325842e-02, 7.1721255779e-01, 5.4174715281e-01,
8.7252455950e-01, 2.3034229875e-01],
[ 8.6751633883e-01, 8.9080280066e-01, 8.8437432051e-01,
-7.2503603995e-02, 8.5934364796e-01, 5.0289857388e-01,
1.0022629499e+00, 1.8224236369e-01],
[ 8.8764733076e-01, 9.1166216135e-01, 9.1891723871e-01,
-7.7194452286e-02, 8.8277554512e-01, 5.2277541161e-01,
1.0235692263e+00, 1.9542214274e-01],
[ 8.9673584700e-01, 8.6170548201e-01, 8.2436698675e-01,
1.4161440730e-01, 6.5297460556e-01, 6.3108932972e-01,
7.9818397760e-01, 2.6648667455e-01],
[ 8.8351935148e-01, 8.8325554132e-01, 8.5816204548e-01,
-8.0196680501e-03, 7.9587674141e-01, 5.3751409054e-01,
9.4600826502e-01, 2.0341506600e-01],
[ 8.1397992373e-01, 8.2838648558e-01, 8.5030966997e-01,
-8.8126033545e-02, 8.3583903313e-01, 4.5510044694e-01,
9.9602252245e-01, 1.5417987108e-01],
[ 6.1925458908e-01, 5.9385389090e-01, 5.2979224920e-01,
-3.5912264138e-02, 4.9642848969e-01, 3.3578324318e-01,
7.0547688007e-01, 7.8385539353e-02],
[ 7.8108674288e-01, 6.6775715351e-01, 6.1285018921e-01,
3.1544300914e-01, 3.6865007877e-01, 6.7531591654e-01,
6.4135974646e-01, 3.5447835922e-01],
[ 8.3815830946e-01, 7.8540831804e-01, 7.2268778086e-01,
5.5315937847e-02, 6.7108374834e-01, 4.9458086491e-01,
8.4645766020e-01, 2.3364454508e-01],
[ 5.1406806707e-01, 5.2309536934e-01, 5.0807440281e-01,
-1.9287040830e-01, 5.8621060848e-01, 1.5216127038e-01,
8.3681327105e-01, -6.7267514765e-02],
[ 4.1226357222e-01, 3.9164602757e-01, 3.0982518196e-01,
-1.3932901621e-01, 3.6478865147e-01, 1.1988976598e-01,
6.1140972376e-01, -7.3039494455e-02],
[ 4.7460585833e-01, 4.3725985289e-01, 2.8978574276e-01,
7.2090071440e-01, -2.4166262150e-01, 8.2806438208e-01,
6.1434850097e-02, 6.0027962923e-01],
[ 7.9462063313e-01, 7.3788356781e-01, 6.8965417147e-01,
1.9477413595e-01, 5.3217446804e-01, 5.9116268158e-01,
6.9491815567e-01, 2.9300740361e-01],
[ 4.0884870291e-01, 4.2614352703e-01, 4.3263792992e-01,
-2.7052247524e-01, 5.5374783278e-01, 2.5412589312e-02,
8.5867840052e-01, -1.6064369678e-01],
[ 1.1232741177e-01, 9.4285562634e-02, -2.8797790408e-02,
-2.8413486481e-01, 1.1447374523e-01, -1.2736788392e-01,
3.8964253664e-01, -2.7499958873e-01],
[ 4.0609455109e-01, 4.1544246674e-01, 2.3380063474e-01,
8.1382852793e-01, -3.6809879541e-01, 8.4226578474e-01,
-9.0999066830e-02, 6.8012529612e-01],
[ 7.5890117884e-01, 7.1640801430e-01, 6.4160078764e-01,
2.2286985815e-01, 4.8637318611e-01, 5.6602346897e-01,
6.3253593445e-01, 3.3189448714e-01],
[ 1.8294855952e-02, 4.1917800903e-02, 1.9956126809e-02,
-4.5387506485e-01, 3.2704287767e-01, -3.3966571093e-01,
7.0831686258e-01, -4.3913036585e-01],
[-3.8637369871e-01, -4.5166379213e-01, -4.6067780256e-01,
-5.8770400286e-01, -3.9115294814e-02, -5.3764253855e-01,
2.3831032217e-01, -6.1640590429e-01],
[ 2.9319143295e-01, 2.8827565908e-01, 6.5197840333e-02,
8.6398667097e-01, -5.9744858742e-01, 8.3846032619e-01,
-2.4567192793e-01, 6.7352229357e-01],
[ 5.2184075117e-01, 5.0640445948e-01, 3.5496354103e-01,
3.6912393570e-01, 8.2235828042e-02, 5.9509968758e-01,
2.9390001297e-01, 3.5063615441e-01],
[-2.9110723734e-01, -3.1803846359e-01, -3.5329234600e-01,
-5.5038774014e-01, 2.4381056428e-02, -5.0387620926e-01,
3.2649433613e-01, -5.5301195383e-01],
[-7.2178471088e-01, -8.0932444334e-01, -7.1661406755e-01,
-6.4933156967e-01, -2.3751400411e-01, -5.6131917238e-01,
-1.1282922328e-01, -7.3082917929e-01],
[ 1.5074957907e-01, 9.7607210279e-02, -1.0853867233e-01,
1.0877382755e+00, -1.0112128258e+00, 9.1293346882e-01,
-5.6111729145e-01, 7.7494198084e-01],
[ 5.4971641302e-01, 5.4886144400e-01, 4.1464430094e-01,
5.7090681791e-01, 1.5562146902e-02, 7.1120524406e-01,
1.9895754755e-01, 5.4621404409e-01],
[-3.9366084337e-01, -4.4121766090e-01, -4.5996707678e-01,
-5.5510425568e-01, -1.1685043573e-03, -5.3734993935e-01,
2.4105130136e-01, -6.1151641607e-01],
[-5.2276879549e-01, -5.4054433107e-01, -5.0406098366e-01,
-6.1084896326e-01, -1.4487281442e-02, -6.0916495323e-01,
2.0622767508e-01, -6.7569190264e-01],
[-3.6969506741e-01, -4.5419234037e-01, -8.5087037086e-01,
5.9729576111e-01, -1.3964147568e+00, 4.8319572210e-01,
-1.0880144835e+00, 3.2983240485e-01],
[ 2.5268208981e-01, 2.6167953014e-01, 2.8274446726e-02,
7.4333733320e-01, -5.5177694559e-01, 7.4476432800e-01,
-2.3858781159e-01, 5.9774476290e-01],
[-5.3853386641e-01, -5.8221745491e-01, -5.9578448534e-01,
-5.2591758966e-01, -1.6031323373e-01, -4.9402362108e-01,
-3.4457042813e-02, -6.1920011044e-01],
[-6.1508226395e-01, -7.0803725719e-01, -7.0008456707e-01,
-5.6278920174e-01, -2.5570809841e-01, -5.3972440958e-01,
-1.1528758705e-01, -6.3002502918e-01],
[ 5.8707669377e-02, 1.8676459789e-02, -2.2664265335e-01,
9.2681169510e-01, -1.0662351847e+00, 8.2683765888e-01,
-5.8269971609e-01, 6.1325526237e-01],
[ 4.6706598997e-01, 4.5936805010e-01, 3.1905025244e-01,
7.0009756088e-01, -1.6421130300e-01, 8.0398738384e-01,
7.9242959619e-02, 5.8904093504e-01],
[-5.7732629776e-01, -6.3657850027e-01, -6.1326271296e-01,
-5.9521275759e-01, -1.5290783346e-01, -5.5260998011e-01,
1.1421605945e-02, -6.6122150421e-01],
[-8.7108415365e-01, -7.0264959335e-01, -5.4622644186e-01,
-7.2423720360e-01, 6.8168237805e-02, -7.8077685833e-01,
1.3776995242e-01, -7.8278994560e-01],
[ 2.2678442299e-01, 2.3853041232e-01, -1.2939646840e-02,
7.8404253721e-01, -6.2871348858e-01, 7.6879101992e-01,
-2.8913736343e-01, 5.9305030107e-01],
[ 1.0347022116e-01, 1.4599119127e-01, -1.9325403869e-01,
5.6087404490e-01, -6.3103687763e-01, 6.0161834955e-01,
-3.6667895317e-01, 4.0997803211e-01],
[-6.8918240070e-01, -7.5843018293e-01, -7.0153790712e-01,
-6.0384953022e-01, -2.2374112904e-01, -5.8500319719e-01,
-1.0391147435e-01, -6.8481385708e-01]], grad_fn=<StackBackward>)
The data I am using with this VAE was preprocessed with `MinMaxScaler`, as implemented in scikit-learn, to the range (-1, 1). The problem I see here is that `logvars_decoder` became large in magnitude and negative. I will investigate this problem further. I suppose that, in principle, I could “penalize” `logvars_decoder` so that its magnitude remains small?