I was doing it the wrong way: the model output is not a probability, it is logits.
I get flip_op the same way as output, except that the input data is horizontally flipped.
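For context, this is roughly how I produce flipped_data (I'm assuming the last dimension of the 5D tensor is the width, which is also why I flip dim 4 of the prediction below):
import torch

# dummy batch with the same layout as mine, (N, C, D, H, W), standing in for my real input
data = torch.randn(16, 1, 8, 224, 224)
flipped_data = torch.flip(data, dims=[4])  # flipping W = horizontal flip
assert torch.equal(torch.flip(flipped_data, dims=[4]), data)  # flipping twice restores the input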
To apply a KL consistency loss on the segmentation output, I tried this:
import torch
import torch.nn.functional as F

conf_consistency_criterion = torch.nn.KLDivLoss(reduction='none').cuda()

output = model(data, label)          # logits, shape (16, 1, 8, 224, 224)
flip_op = model(flipped_data, label) # logits, shape (16, 1, 8, 224, 224)
output = torch.flip(output, [4])     # flip back along the width dim to align with flip_op
output = torch.mean(output, 2)       # average over depth -> (16, 1, 224, 224)
flip_op = torch.mean(flip_op, 2)     # (16, 1, 224, 224)
output = torch.squeeze(output, 1)    # (16, 224, 224)
flip_op = torch.squeeze(flip_op, 1)  # (16, 224, 224)

# CONS_LOSS
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()
I’m getting a negative value, so something is wrong somewhere: KL divergence measures how close two probability distributions are, and it should never be negative.
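If I understand KLDivLoss correctly, the input is expected to contain log-probabilities of a full distribution and the target the corresponding probabilities, so with a sigmoid output each pixel needs both the foreground term p and the background term 1 - p; feeding only the foreground term lets individual entries go negative. A standalone sketch of what I mean (random tensors, not my actual model outputs):
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kl = torch.nn.KLDivLoss(reduction='none')

# two random per-pixel logit maps standing in for output and flip_op
logits_a = torch.randn(4, 4)
logits_b = torch.randn(4, 4)

# what I did above: only the foreground term -> individual entries can be < 0
single = kl(F.logsigmoid(logits_a), torch.sigmoid(logits_b))
print(single.min())

# treating each pixel as a 2-class (foreground/background) distribution:
# log-probabilities as input, probabilities as target, summed over the classes
log_p_a = torch.stack([F.logsigmoid(logits_a), F.logsigmoid(-logits_a)], dim=-1)
p_b = torch.stack([torch.sigmoid(logits_b), 1 - torch.sigmoid(logits_b)], dim=-1)
per_pixel_kl = kl(log_p_a, p_b).sum(-1)
print(per_pixel_kl.min())  # non-negative for every pixel

This matches my understanding that KL between two valid distributions cannot be negative, but I'm not sure whether building the loss this way is the right approach for segmentation.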
I also tried this:
conf_consistency_criterion = torch.nn.KLDivLoss(reduction='none').cuda()

output = model(data, label)          # logits, shape (16, 1, 8, 224, 224)
flip_op = model(flipped_data, label) # logits, shape (16, 1, 8, 224, 224)
output = torch.flip(output, [4])     # flip back along the width dim to align with flip_op

# CONS_LOSS, this time on the full 5D logits
output += 1e-7
flip_op += 1e-7
cons_loss_a = conf_consistency_criterion(F.logsigmoid(output), torch.sigmoid(flip_op)).sum(-1).mean()
This gives me a positive value, but the loss suddenly drops from about 20 to 0.006 in the second epoch. What is the right approach here?