KL loss on feature maps

Hi,

I am trying to minimize the distance between two feature maps using KL divergence. Both outputs have the shape (16, 1, 8, 224, 224).

One map is the prediction on the RGB clip and the other is the prediction on a horizontally flipped version of the same clip. I’m getting a NaN loss.

output = model(data, label)            # prediction on the original clip: 16, 1, 8, 224, 224
flip_op = model(flipped_data, label)   # prediction on the horizontally flipped clip
output = torch.flip(output, [4])       # flip the prediction along the width dim so it aligns with flip_op
output = torch.mean(output, 2)         # 16, 1, 224, 224
flip_op = torch.mean(flip_op, 2)
output = torch.squeeze(output, 1)      # 16, 224, 224
flip_op = torch.squeeze(flip_op, 1)
# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = consistency_criterion(flip_op.log(), output.detach()).sum(-1).mean()

I’m not sure how flip_op is calculated, but did you make sure that the log() operation doesn’t produce invalid values?

I was doing it the wrong way: the values were not probabilities, they were raw logits.
I get flip_op the same way as output, just with the data horizontally flipped.
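For reference, calling log() directly on logits produces NaN as soon as a value is negative. A quick check on a made-up tensor (not the actual model output) illustrates this, and shows sigmoid/logsigmoid as one way to get valid log-probabilities first:

import torch
import torch.nn.functional as F

logits = torch.tensor([-1.5, 0.2, 2.0])     # raw network outputs can be negative
print(logits.log())                         # tensor([   nan, -1.6094,  0.6931])
print(torch.sigmoid(logits).log())          # finite: map to (0, 1) first
print(F.logsigmoid(logits))                 # same thing, numerically stabler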

To apply KL divergence to the segmentation maps, I tried this:

import torch.nn.functional as F

conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()  # elementwise loss, i.e. reduction='none'
output = model(data, label)            # 16, 1, 8, 224, 224
flip_op = model(flipped_data, label)   # 16, 1, 8, 224, 224
output = torch.flip(output, [4])
output = torch.mean(output, 2)         # 16, 1, 224, 224
flip_op = torch.mean(flip_op, 2)       # 16, 1, 224, 224
output = torch.squeeze(output, 1)      # 16, 224, 224
flip_op = torch.squeeze(flip_op, 1)    # 16, 224, 224
# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()

I’m getting a negative output, so something must be wrong somewhere. KL divergence is a measure of how close two probability distributions are, and it should be non-negative.
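One possible explanation (an assumption on my part): KLDivLoss computes target * (log target - input) per element, which is only a complete KL divergence after summing over a dimension along which the values form a probability distribution. With independent per-pixel sigmoid probabilities, the full Bernoulli KL also needs the (1 - p) complement term; the pointwise term alone can be negative. A toy sketch:

import torch
import torch.nn.functional as F

p = torch.tensor([0.5])   # "target" probability for one pixel
q = torch.tensor([0.9])   # "input" probability for the same pixel

# Pointwise term only: p * (log p - log q) -- can be negative on its own
print(F.kl_div(q.log(), p, reduction='none'))   # tensor([-0.2939])

# Full Bernoulli KL adds the complement term and is non-negative
full_kl = p * (p / q).log() + (1 - p) * ((1 - p) / (1 - q)).log()
print(full_kl)                                  # tensor([0.5108])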

I also tried this:

conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()  # elementwise loss, i.e. reduction='none'
output = model(data, label)            # 16, 1, 8, 224, 224
flip_op = model(flipped_data, label)   # 16, 1, 8, 224, 224
output = torch.flip(output, [4])
# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()

This gives me a positive value. However, the value suddenly drops from 20 to 0.006 in the second epoch. I don’t understand which is the right approach.

I’m not getting a negative loss with either of the two approaches:

import torch
import torch.nn.functional as F

conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()
output = torch.randn(16, 1, 8, 224, 224)
flip_op = torch.randn(16, 1, 8, 224, 224)
output = torch.flip(output, [4])
output = torch.mean(output, 2)         # 16, 1, 224, 224
flip_op = torch.mean(flip_op, 2)       # 16, 1, 224, 224
output = torch.squeeze(output, 1)      # 16, 224, 224
flip_op = torch.squeeze(flip_op, 1)    # 16, 224, 224
# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()
print(cons_loss_a)


output = torch.randn(16, 1, 8, 224, 224)
flip_op = torch.randn(16, 1, 8, 224, 224)
output = torch.flip(output, [4])    
# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()
print(cons_loss_a)

So I don’t know what might be causing this issue.

So, either approach can be used to compute the JSD loss?
The min/max values of both of my segmentation maps range between -2 and 2.
My ground truth map is binary (0/1).

Yes, I think so. Based on the docs, the inputs should be given as log-probabilities and the targets as probabilities by default, which seems to be the case for you.
I’m not familiar with your use case, so I cannot comment on whether (log)sigmoid or (log)softmax should be used.
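As a minimal sketch of that convention (toy tensors with an assumed class dimension and softmax, not your segmentation setup):

import torch
import torch.nn.functional as F

criterion = torch.nn.KLDivLoss(reduction='none')

logits_a = torch.randn(4, 3)            # toy predictions over 3 "classes"
logits_b = torch.randn(4, 3)

inp = F.log_softmax(logits_a, dim=1)    # inputs: log-probabilities
tgt = F.softmax(logits_b, dim=1)        # targets: probabilities

kl = criterion(inp, tgt).sum(dim=1)     # sum over the distribution dimension
print(kl)                               # non-negative per sample
print(kl.mean())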

I have a binary output, so sigmoid is the right choice, isn’t it?