Correct implementation of VAE loss

I have some doubts about the correct implementation of the variational autoencoder (VAE) loss. This is the one I’ve been using so far:

def vae_loss(recon_loss, mu, logvar):

    # per-sample KL divergence, summed over latent dims: shape [batch_size]
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

    return recon_loss + KLD

After noticing problems with my loss convergence, even in simple tasks such as reconstructing 1D vectors, I started googling around and found this variation:

def vae_loss(recon_loss, mu, logvar):

    # same per-sample KL divergence, then averaged over the batch (dim=0): a scalar
    KLD = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1), dim=0)

    return recon_loss + KLD

With this second VAE loss, performance improves noticeably, and the difference is clearly visible in the reconstructed vectors as well. As far as I can tell, the second implementation takes the average of the KL term over the batch of samples.
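A quick way to confirm that reading is to compare the two KL terms on toy tensors. In this sketch mu and logvar are random stand-ins (not real encoder outputs) with a hypothetical batch of 4 and latent size of 2; the first variant yields one value per sample, the second yields their batch mean as a 0-dimensional tensor.

```python
import torch

torch.manual_seed(0)
batch_size, latent_size = 4, 2  # hypothetical toy sizes
mu = torch.randn(batch_size, latent_size)
logvar = torch.randn(batch_size, latent_size)

# First variant: per-sample KL term, shape [batch_size]
kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# Second variant: the same values averaged over the batch dimension (dim=0)
kld_batch_mean = torch.mean(kld_per_sample, dim=0)

print(kld_per_sample.shape)  # torch.Size([4])
print(kld_batch_mean.shape)  # torch.Size([])
```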

My confusion comes from the fact that I found these two implementations on different blog posts, and I don’t know which one is correct.

If you don’t reduce over the batch dimension in the loss function, you are bound to do so elsewhere, because calling backward() without an explicit gradient argument only works for scalar losses.
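The scalar requirement is easy to reproduce. In this minimal sketch, x.pow(2) is an arbitrary toy per-sample loss: calling backward() on the unreduced tensor raises a RuntimeError, while reducing it with mean() first works.

```python
import torch

x = torch.randn(3, requires_grad=True)
per_sample_loss = x.pow(2)  # shape [3]: one loss value per sample, not a scalar

failed = False
try:
    per_sample_loss.backward()  # no gradient argument given
except RuntimeError as err:
    # autograd can only create the implicit gradient for a scalar output
    failed = True
    print("backward failed:", err)

# Reducing over the batch dimension first makes backward() work
per_sample_loss.mean().backward()
print(x.grad)  # d/dx mean(x**2) = 2 * x / 3
```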

Typically, we might think of the optimization as minimizing the expected loss function on the training distribution, so taking the sample mean would seem to be a good way to reduce the batch dimension.
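Whether the batch is reduced with a mean or a sum does not change the direction of the gradient, only its scale, which differs by exactly the batch size. The sketch below checks this on a hypothetical least-squares toy problem with 8 samples.

```python
import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)
inputs = torch.randn(8, 5)  # toy batch of 8 samples
targets = torch.randn(8)

def per_sample_loss(w):
    return (inputs @ w - targets).pow(2)  # shape [8]

per_sample_loss(w).mean().backward()
grad_mean = w.grad.clone()

w.grad = None
per_sample_loss(w).sum().backward()
grad_sum = w.grad.clone()

# The sum-reduced gradient is batch_size times the mean-reduced one
print(torch.allclose(grad_sum, 8 * grad_mean))  # True
```

The mean keeps the gradient magnitude independent of the batch size, which is one practical reason it is the usual default.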

Best regards

Thomas

Thanks for your answer, but could you please elaborate a bit more on that? Is the second way I proposed wrong, then?

No. But on further thought, you might have some broadcasting interaction between recon_loss and KLD.

If recon_loss is a scalar (maybe you can print the shapes of these tensors to further the discussion):

  • and KLD is not (as in the first variant), adding the two essentially adds recon_loss to every entry of KLD.
  • In the second variant, recon_loss is added once to the scalar KLD.

So in the first variant, recon_loss is weighted batch_size times higher than in the second.
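The broadcasting can be seen with hand-picked numbers. In this sketch recon_loss and the per-sample KLD values are made up for illustration: the first variant produces a vector with recon_loss added to every entry, so any later sum over the batch counts recon_loss batch_size times.

```python
import torch

batch_size = 4
recon_loss = torch.tensor(2.0)            # scalar, e.g. a mean-reduced MSE
kld = torch.tensor([1.0, 2.0, 3.0, 4.0])  # per-sample KL, shape [batch_size]

# First variant: recon_loss is broadcast and added to every KLD entry
first = recon_loss + kld                  # tensor([3., 4., 5., 6.])
# Second variant: recon_loss is added once to the batch-mean KLD
second = recon_loss + kld.mean()          # tensor(4.5000)

# Summing the first variant later counts recon_loss batch_size times
print(first.sum())                        # tensor(18.)
print(batch_size * recon_loss + kld.sum())  # tensor(18.)
```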

Now, if recon_loss is not a scalar, it works the other way round: the second variant broadcasts KLD, so again the weighting between the two terms shifts.
So either way, it looks to me that, relative to recon_loss, KLD is weighted much higher in the second regime than in the first.
Depending on your optimizer, the overall scale of the loss might also have an influence (yes for SGD, largely no for Adam and co.).
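The optimizer point can be illustrated with one step on a toy loss scale * w**2 at two different scales. This is a sketch; the toy loss and the helper first_step_update are hypothetical. SGD’s step grows with the loss scale, while Adam’s first step is roughly lr in both cases because it divides by the gradient’s magnitude (up to its eps term).

```python
import torch

def first_step_update(opt_cls, loss_scale, **opt_kwargs):
    """Return the parameter change after one optimizer step on scale * w**2."""
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([w], **opt_kwargs)
    loss = loss_scale * w.pow(2).sum()
    loss.backward()
    before = w.detach().clone()
    opt.step()
    return (w.detach() - before).item()

results = {}
for opt_cls in (torch.optim.SGD, torch.optim.Adam):
    small = first_step_update(opt_cls, 1.0, lr=0.01)
    large = first_step_update(opt_cls, 100.0, lr=0.01)
    results[opt_cls.__name__] = (small, large)
    print(opt_cls.__name__, small, large)
```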

I would recommend printing KLD, recon_loss and their sum to better see what’s going on.
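A small helper makes that check quick. The function describe and the random stand-in tensors below are hypothetical; inside the real vae_loss() one would pass the actual recon_loss and KLD. The returned shape shows immediately whether the sum has been broadcast to a per-sample vector.

```python
import torch

def describe(name, t):
    """Print the shape and mean of a loss tensor; return the shape for checks."""
    print(f"{name}: shape={tuple(t.shape)}, mean={t.detach().float().mean().item():.4f}")
    return tuple(t.shape)

# Hypothetical stand-ins for the real tensors inside vae_loss()
recon_loss = torch.tensor(1.5)
KLD = torch.rand(16)

describe("recon_loss", recon_loss)
describe("KLD", KLD)
describe("recon_loss + KLD", recon_loss + KLD)
```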

Best regards

Thomas

Alright, as requested, here are additional details about my problem. I’m trying to reconstruct simple 8-dimensional 1D vectors (plus a 1-dimensional label vector) with a conditional variational autoencoder model:

import torch
import torch.nn as nn

class CVAE(BaseModel):  # BaseModel: the poster's own nn.Module subclass (not shown), which provides fit()

    def __init__(self, in_size, target_size, latent_size):

        super().__init__()
        self.latent_size = latent_size

        # the input is concatenated to the target property
        self.encoder = nn.Sequential(
            nn.Linear(in_size + target_size, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, latent_size * 2),  # mu and logvar
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_size + target_size, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, in_size),
        )

    def reparameterise(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
            eps = torch.randn_like(std)    # eps ~ N(0, I)
            return mu + eps * std
        else:
            return mu

    def encode(self, x, cond):
        x = torch.cat([x, cond], dim=1)
        mu_logvar = self.encoder(x).view(-1, 2, self.latent_size)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, cond):

        mu, logvar = self.encode(x, cond)
        z = self.reparameterise(mu, logvar)
        z = torch.cat([z, cond], dim=1)
        x_hat = self.decode(z)

        # first 4 outputs: non-negative values; remaining outputs: a probability vector
        part1 = torch.relu(x_hat[:, :4])
        part2 = torch.softmax(x_hat[:, 4:], dim=1)

        x_hat = torch.cat([part1, part2], dim=1)

        return x_hat, mu, logvar
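The reparameterisation step can be checked on its own to confirm that gradients reach both mu and logvar. This is a standalone sketch, not the poster’s method: it uses torch.randn_like instead of the older std.data.new(...).normal_() idiom, and the training flag is a hypothetical stand-in for self.training.

```python
import torch

def reparameterise(mu, logvar, training=True):
    """Sample z = mu + sigma * eps with eps ~ N(0, I); return mu at eval time."""
    if not training:
        return mu
    std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # noise sample, carries no gradient
    return mu + eps * std

mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)
z = reparameterise(mu, logvar)
z.sum().backward()
print(mu.grad)      # all ones, since dz/dmu = 1
print(logvar.grad)  # 0.5 * eps * std, i.e. random values
```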

The additional method and class I use to compute recon_loss are the following:

from tqdm import tqdm

def fit(self, dataloader, optimizer, criterion):
    self.train()
    running_loss = 0.0
    # len(dataloader) is the number of batches, including a smaller final batch
    for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
        optimizer.zero_grad()
        reconstruction, mu, logvar = self(data[0], data[1])
        loss = criterion(reconstruction, data[0])
        loss = vae_loss(loss, mu, logvar)  # must be a scalar for .item() and .backward()
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    # running_loss is a sum of per-batch mean losses, so average over batches
    train_loss = running_loss / len(dataloader)

    return train_loss
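If an exact per-sample average over the epoch is wanted even when the last batch is smaller, one option is to weight each batch’s mean loss by its size before dividing by the dataset size. The toy dataset and the squared-value "loss" here are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 10 samples with batch_size=4, so the last batch has only 2 samples
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=4)

running_loss = 0.0
for (x,) in loader:
    batch_loss = x.pow(2).mean()                   # a mean-reduced loss, like nn.MSELoss()
    running_loss += batch_loss.item() * x.size(0)  # undo the per-batch mean

epoch_loss = running_loss / len(loader.dataset)    # exact mean over all samples
print(epoch_loss)  # mean of 0**2 .. 9**2 = 28.5
print(len(loader))  # 3 batches
```

Note that int(len(dataset) / batch_size) would give 2 here, one fewer than the actual number of batches, which is why len(loader) is the safer progress-bar total.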

------------------------------------

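import torch.nn as nn
from torch.nn.modules.loss import _Loss

class CustomLoss(_Loss):

    def __init__(self):

        super().__init__()
        self.mse = nn.MSELoss()
        self.cos = nn.CosineEmbeddingLoss()

    def forward(self, input, target):
        """ loss function called at runtime """

        # Class 1: MSE on the first 4 columns
        class_1_loss = self.mse(
            input[:, :4],
            target[:, :4])

        # Class 2: cosine embedding loss on the remaining columns;
        # label 1 marks each pair as "similar", created on input's device and dtype
        class_2_loss = self.cos(
            input[:, 4:],
            target[:, 4:],
            input.new_ones(input.shape[0]))

        return class_1_loss + class_2_loss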
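The same reconstruction loss can be written with the functional API, which makes it easy to sanity-check: for a perfect reconstruction both terms should be (numerically) zero. This is a sketch; custom_loss and its split parameter are hypothetical names mirroring the hard-coded column 4.

```python
import torch
import torch.nn.functional as F

def custom_loss(input, target, split=4):
    """MSE on the first `split` columns plus cosine embedding loss on the rest."""
    class_1_loss = F.mse_loss(input[:, :split], target[:, :split])
    # label 1 means "these pairs should be similar"; new_ones keeps device/dtype
    y = input.new_ones(input.shape[0])
    class_2_loss = F.cosine_embedding_loss(input[:, split:], target[:, split:], y)
    return class_1_loss + class_2_loss

torch.manual_seed(0)
x = torch.rand(5, 8)
print(custom_loss(x, x))  # approximately 0 for a perfect reconstruction
```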

Taking a single batch from the DataLoader and computing recon_loss:

data = next(iter(data_loader))

criterion = CustomLoss()

recon, mu, logvar = model(data[0], data[1])

recon_loss = criterion(recon, data[0])

where

recon_loss
Out[18]: tensor(1281.3059, grad_fn=<AddBackward0>)

Next, in vae_loss() we compute

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),dim=1)

KLD
Out[23]: 
tensor([10.0125,  9.2514,  8.6930, 11.5028,  6.2790, 11.4213, 11.1161,  8.7398,
         8.7730, 11.4224, 10.7966,  8.1383,  7.9794,  8.4988,  8.1407,  7.8110,
         7.8110,  7.9266,  7.5567,  7.7040,  9.9480,  7.6494, 10.5388,  9.4197,
        10.1101,  8.4504,  8.7865,  8.8465, 10.7121,  8.4979,  8.6045,  8.3936,
         8.5527,  7.2297,  8.4476, 10.8846, 10.4522, 11.5168,  7.7652,  9.0135,
         8.9415, 10.0520,  9.4360, 10.9811,  9.3636, 10.0256,  7.2043,  7.6451,
         7.7242,  7.6227, 10.8506,  8.6256,  7.5179, 10.1579,  8.7330, 10.3774,
         9.8613,  8.9309, 10.0385,  9.0261,  9.4356,  9.6858, 10.1660,  8.4929,
         8.7968,  7.5675,  8.2790,  8.7619,  9.4661, 10.3707, 10.6991, 10.1204,
        11.2257, 11.0965,  8.7320, 10.7721,  9.4106,  9.6219,  8.5730, 12.0483,
         6.5600, 10.1521, 10.1500, 10.1347, 10.1345, 10.1373,  8.4438,  6.3176,
         8.5711,  8.7008,  9.7572, 11.4712, 10.6697, 11.0056, 10.6899, 10.3529,
         8.3172, 10.5426,  8.4198, 10.1392, 11.1788,  8.7461,  8.5806,  7.9725,
         8.7498, 10.9897, 11.0135,  8.6260,  9.3328,  9.3445,  9.6393,  6.3542,
         9.4578,  9.5768, 11.3704,  9.4054, 11.4350,  9.7938,  9.7804, 10.1121,
        11.1820,  6.0557, 10.4342,  7.9478,  6.3706, 11.0652,  9.2101,  8.6107],
       grad_fn=<MulBackward0>)
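Each of those per-sample values is the closed-form KL divergence between the diagonal Gaussian N(mu, sigma^2) and the standard normal prior, summed over latent dimensions. The sketch below cross-checks the formula against torch.distributions.kl_divergence on random inputs; the batch of 128 matches the output above, while the latent size of 8 is a hypothetical choice since latent_size is not given.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(128, 8)      # hypothetical latent_size = 8
logvar = torch.randn(128, 8)

# Closed form used in vae_loss(): KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# Same quantity via torch.distributions (std = exp(logvar / 2))
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_ref = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(kld_closed, kld_ref, atol=1e-4))
```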

Here I’ve just realized that unless we take a torch.mean over these values, recon_loss + KLD in vae_loss() is not a scalar: the scalar recon_loss is broadcast over the 128 per-sample KLD values, so loss.item() and loss.backward() in fit() will fail.
So taking the mean seems reasonable here, but I still don’t understand whether it makes sense from a conceptual standpoint.

Taking the mean can be seen as an estimator of the expectation, so I’d say it does make sense here. (That is in contrast to, e.g., taking the mean over the positions of a sequence, which can be much more touchy, because summing the per-token log-likelihoods gives you the log-likelihood of the whole sequence under the model.)
And you really want recon_loss and KLD to be either both reduced or both unreduced, so that you don’t accidentally introduce an imbalance in the loss weighting; taking the mean of KLD first is a good option, in my opinion.
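Following that advice, a version of vae_loss() where both terms are reduced the same way might look like the sketch below. It assumes recon_loss comes from a mean-reduced criterion (as with the CustomLoss above) and adds an assert, a hypothetical safeguard, so that an unreduced recon_loss fails loudly instead of silently broadcasting.

```python
import torch

def vae_loss(recon_loss, mu, logvar):
    """Scalar VAE loss: mean-reduced recon_loss plus batch-mean KL divergence."""
    # Guard against the broadcasting issue: recon_loss must already be reduced
    assert recon_loss.dim() == 0, "recon_loss should be a scalar (mean-reduced)"
    # KL(q(z|x) || N(0, I)) per sample, summed over latent dims: shape [batch_size]
    kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # Reduce over the batch the same way recon_loss was reduced (mean)
    return recon_loss + kld_per_sample.mean()

mu = torch.zeros(128, 8)
logvar = torch.zeros(128, 8)
loss = vae_loss(torch.tensor(0.3), mu, logvar)
print(loss)  # tensor(0.3000): the KL term is zero when q(z|x) equals the prior
```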

Best regards

Thomas