I am trying to use the KL loss for a VAE. The latents, i.e. the output of the encoder, have shape
`z.shape = [batch_size, channels, height, width]`, e.g. `[2, 64, 4, 4]`,
and I want to determine the KL divergence against a uniform prior over the channel dimension.
Now I am a little confused about what to sum and what to average. The docs state that one should use `'batchmean'` for a mathematically correct KL across samples, i.e. batches. However, in my case I also have the spatial dimensions height and width, and I want to evaluate the KL divergence for each pixel separately and only then average. So is it correct to use
```python
log_p = F.log_softmax(z, dim=1)                     # turn latents into log-probabilities
log_q = torch.log(torch.ones_like(z) / z.shape[1])  # uniform prior over the channel dim
kl_loss = F.kl_div(log_p, log_q, reduction='batchmean', log_target=True)
```
or the manual
```python
p = F.softmax(z, dim=1)              # turn latents into probabilities
q = torch.ones_like(z) / z.shape[1]  # uniform prior over the channel dim
log_ratio = torch.log(p / q)
kl_loss = torch.sum(p * log_ratio, dim=1).mean()  # sum each KL over the probability dim, then average
```
which does give a different result but intuitively feels correct: first a summation of the KL over the channel/probability dimension, followed by averaging over both the batch and the spatial extent dimensions.
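For reference, here is a small self-contained comparison I put together. Note that I pass the log-prior as the first (`input`) argument here, since `F.kl_div(input, target)` computes KL(target ‖ input), so this ordering gives KL(p ‖ q) like the manual version; with that aligned, the remaining discrepancy seems to come purely from the reduction, since `'batchmean'` divides the total sum by the batch size only:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 64, 4, 4
z = torch.randn(B, C, H, W)

log_p = F.log_softmax(z, dim=1)      # posterior log-probabilities over channels
p = log_p.exp()
q = torch.full_like(z, 1.0 / C)      # uniform prior over the channel dim

# manual: sum KL over the channel dim, then average over batch and spatial dims
manual = torch.sum(p * (log_p - q.log()), dim=1).mean()

# built-in: input = log-prior, target = log-posterior, so this is KL(p || q);
# 'batchmean' sums over ALL non-batch elements and divides by B only
builtin = F.kl_div(q.log(), log_p, reduction='batchmean', log_target=True)

print(manual.item(), builtin.item())
print(torch.isclose(builtin, manual * H * W))  # differ by a factor of H * W
```

So with spatial dimensions present, `'batchmean'` effectively sums the per-pixel KLs instead of averaging them, which is why the two numbers disagree.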
Can someone help me out here please?