Masking with KLDivLoss

I want to compute the KL divergence between two batches of distributions. x is my tensor of predicted distributions and target contains the target distributions. Both x and target have shape (batch_size, max_dist_size). Each row of x and target holds a distribution whose support size is n <= max_dist_size, and the per-row support sizes are stored in a list dist_size.

I am currently considering doing something like this:

criterion = nn.KLDivLoss(size_average=False)  # sum rather than average over elements
l = 0.
for i in range(x.size(0)):
  # slice out the valid support of row i; note that KLDivLoss expects
  # log-probabilities as its input and probabilities as its target
  l += criterion(x[i, :dist_size[i]].unsqueeze(0), target[i, :dist_size[i]].unsqueeze(0))

Is there a better way to use the dist_size list to mask and obtain the KL divergence over the entire batch?

Assuming dist_size is a LongTensor:

cols = torch.arange(max_dist_size).long().unsqueeze(0).expand(batch_size, -1)
mask = (cols < dist_size.unsqueeze(1)).float()
l = criterion(x * mask, target * mask)

Every row of cols is just 0, 1, ..., max_dist_size - 1, so comparing it against each row's size zeroes out the padded columns; the zeroed targets then contribute nothing to the loss.
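Here is a small self-contained sketch checking that the masked, single-call version matches the per-row loop. It uses the current reduction='sum' argument in place of the deprecated size_average=False, and the data (batch_size, dist_size, random distributions) is made up for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, max_dist_size = 4, 6
dist_size = torch.tensor([3, 6, 2, 5])  # support size of each row

# Target distributions, zero-padded beyond each row's support
target = torch.rand(batch_size, max_dist_size)
for i, n in enumerate(dist_size):
    target[i, n:] = 0.
target = target / target.sum(dim=1, keepdim=True)

# Predicted log-probabilities over each row's support, zero elsewhere
logits = torch.rand(batch_size, max_dist_size)
x = torch.zeros(batch_size, max_dist_size)
for i, n in enumerate(dist_size):
    x[i, :n] = F.log_softmax(logits[i, :n], dim=0)

criterion = nn.KLDivLoss(reduction='sum')

# Baseline: loop over rows, slicing out the valid prefix of each
loss_loop = sum(
    criterion(x[i, :n], target[i, :n]) for i, n in enumerate(dist_size)
)

# Vectorized: mask out padded columns, then one criterion call.
# Zeroed targets contribute 0 to the KL sum, so the results agree.
cols = torch.arange(max_dist_size).unsqueeze(0).expand(batch_size, -1)
mask = (cols < dist_size.unsqueeze(1)).float()
loss_masked = criterion(x * mask, target * mask)

assert torch.allclose(loss_loop, loss_masked)
```

This works because KLDivLoss computes target * (log(target) - input) elementwise and treats target == 0 entries as contributing zero, so the padded columns drop out of the sum.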