Custom loss function IoU is not differentiable. Can you create a differentiable IoU loss function for ML?

Thank you for reading my post.
I’m a college student, and I’m currently developing a peak detection algorithm that uses a CNN to find the ideal convolution kernel, which can be represented as the ideal mother wavelet function, to maximize peak detection accuracy.

I’ve tried to create my own IoU loss function for training the CNN model, but I failed.
My loss function is shown below.


import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional

'''
1D intersection over union loss function class
'''
class IoU(nn.Module):
  def __init__(self, thresh: float = 0.5):
    super().__init__()
    self.thresh = thresh

  def forward(self, inputs: torch.Tensor, targets:torch.Tensor, weights: Optional[torch.Tensor] = None, smooth: float = 0.0) -> Tensor:

    inputs = torch.where(inputs < self.thresh, 0, 1)
    batch_size = targets.shape[0]

    intersect = torch.logical_and(inputs, targets)
    intersect = intersect.view(batch_size, -1).sum(-1)

    union = torch.logical_and(inputs, targets)
    union = union.view(batch_size, -1).sum(-1)

    IoU = (intersect + smooth) / (union + smooth)
    IoU = IoU.mean()

    return IoU

I then tried to test whether it works with the simple model below.


x = torch.randint(0, 1001, (1, 256)).float()  # e.g. [0, 200, 30, 1000, ...]
true = torch.randint(0, 2, (1, 256))  # e.g. [0, 1, 1, 0, 1, ...]

model = nn.Linear(256, 256)
criterion = IoU()

output = model(x)
loss = criterion(output, true)
loss.backward()  # I'm stuck here, because my loss function IoU is not differentiable

print(f"loss = {loss}")
print(f"model weight grad: {model.weight.grad}")
print(f"model params grads: {[p.grad for p in model.parameters()]}")

The terminal output is: RuntimeError: element 0 of variables does not require grad and does not have a grad_fn

This project is the first time I’ve ever used PyTorch, so I didn’t understand what this meant at first glance. After some quick research, I think I figured out why this loss function fails (though I’m not sure this is correct):

My loss function IoU is not differentiable, and this is where the chain rule breaks in this loss function:

IoU = torch.nan_to_num(IoU)
IoU = IoU.mean()

Soon after noticing this, I searched GitHub and Stack Overflow for other differentiable IoU loss functions, but I’m still not sure how to create one (especially for 1D data).

If you have any experience or insight into what I’m stuck on, please share any guidance. Any advice is welcome, and I hope I can eventually build my own IoU loss function for the CNN model.

Thank you

Hi Passive!

As you have deduced, your IoU loss criterion is not (usefully) differentiable.

First, to clear up some context: Based on your thresholding value of
thresh = 0.5, let me assume that inputs (the output of your model)
are to be thought of as probabilities that range from zero to one (and
represent the probability of the “pixel” in question being in the predicted
“object” whose intersection-over-union you wish to calculate). Note, this
is not consistent with your model being just a Linear whose outputs
range from -inf to +inf and would be thought of as the logits that
correspond to the aforementioned probabilities. If you have such logits
in your real use case, you can convert them to probabilities by passing
them through sigmoid().
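As a small illustration (a sketch assuming a bare Linear layer like the one in the test code, with random inputs), passing its unbounded outputs through sigmoid() gives values in (0, 1) that remain part of the autograd graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same kind of model as in the test code: a bare Linear layer whose
# outputs are unbounded logits rather than probabilities.
model = nn.Linear(256, 256)
x = torch.randn(1, 256)

logits = model(x)              # any real value
probs = torch.sigmoid(logits)  # squashed into (0, 1)

# sigmoid() is an ordinary differentiable op, so probs stays in the graph
print(probs.grad_fn is not None)  # True
```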

You threshold your probabilities to produce “hard,” yes-no predictions for
whether a given pixel is predicted to be in the object, and then use these
hard predictions in your IoU computation. Although this thresholding is mathematically
differentiable almost everywhere (everywhere except x = thresh, where
the derivative is undefined), the derivative is always zero, so it gives your
optimization no useful information. (As an aside, as written, you cannot
backpropagate through your thresholding. But even if you could, you would
just get zero.)
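Both halves of that aside can be seen in a few lines. This is only a sketch with made-up probabilities; torch.round() is used here as a stand-in for a step function that autograd can trace (it is not part of the original loss), and PyTorch defines its derivative as zero:

```python
import torch

probs = torch.tensor([0.2, 0.7, 0.4, 0.9], requires_grad=True)

# 1) Hard threshold as in the original loss: torch.where() with constant
#    branches builds a new integer tensor with no autograd history,
#    which is why backward() raises the RuntimeError.
hard = torch.where(probs < 0.5, 0, 1)
print(hard.requires_grad, hard.grad_fn)  # False None

# 2) A step function that autograd *can* trace still passes back nothing
#    useful: the gradient of torch.round() is zero everywhere.
stepped = torch.round(probs)
stepped.sum().backward()
print(probs.grad)  # all zeros
```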

Calculate instead a “soft,” probabilistic IoU. Replace your logical_and()
with multiplication:

intersect = (inputs * targets).view (batch_size, -1).sum (-1)

(logical_and() is precisely the regular product when all the values are
either zero or one.)
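A quick check of that claim, using made-up 0/1 masks and probabilities for illustration: the product matches logical_and() on hard values, and on soft probabilities it is an ordinary differentiable op whose gradient with respect to the prediction is just the target:

```python
import torch

pred = torch.tensor([0., 1., 1., 0., 1.])
target = torch.tensor([0., 1., 0., 0., 1.])

# On hard 0/1 values the product and logical_and() agree element by element.
assert torch.equal((pred * target).bool(), torch.logical_and(pred, target))

# On soft probabilities the product is differentiable.
probs = torch.tensor([0.1, 0.8, 0.6, 0.3, 0.9], requires_grad=True)
soft_intersect = (probs * target).sum()  # 0.8 + 0.9
soft_intersect.backward()
print(probs.grad)  # equals target, since d(p * t)/dp = t
```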

Note, you most likely have a typo in your expression for union and
presumably meant logical_or().

Off the top of my head, I would replace the logical_or() that you (should)
have in the expression for union with max:

union = torch.max (inputs, targets).view (batch_size, -1).sum (-1)

(Again, max() is the same as logical_or() when all of the values are
zero or one.)
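Putting the product and max() together, a soft 1D IoU loss might look like the sketch below. The class name SoftIoULoss and the smooth default are hypothetical choices, not from this thread. It assumes inputs are already probabilities (e.g. sigmoid() of logits) with the batch along dim 0. It also returns 1 - IoU rather than IoU itself, because an optimizer minimizes the loss, and minimizing raw IoU would push the overlap down:

```python
import torch
import torch.nn as nn


class SoftIoULoss(nn.Module):
    """Hypothetical differentiable 1D IoU loss: 1 - mean soft IoU over the batch.

    Assumes `inputs` are probabilities in [0, 1] and `targets` are 0/1 masks
    of the same shape, with the batch along dim 0.
    """

    def __init__(self, smooth: float = 1e-6):
        super().__init__()
        self.smooth = smooth  # avoids 0 / 0 when both masks are empty

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        batch_size = targets.shape[0]
        targets = targets.to(inputs.dtype)
        # product replaces logical_and(), max replaces logical_or()
        intersect = (inputs * targets).reshape(batch_size, -1).sum(-1)
        union = torch.max(inputs, targets).reshape(batch_size, -1).sum(-1)
        iou = (intersect + self.smooth) / (union + self.smooth)
        return 1.0 - iou.mean()  # minimizing 1 - IoU maximizes IoU


# Usage with a toy model: gradients now reach the Linear layer's weights.
torch.manual_seed(0)
model = nn.Linear(256, 256)
x = torch.randn(4, 256)
targets = torch.randint(0, 2, (4, 256))

criterion = SoftIoULoss()
loss = criterion(torch.sigmoid(model(x)), targets)
loss.backward()
print(loss.item(), model.weight.grad.abs().sum().item())
```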

If you want to threshold inputs, rather than just use the “raw” probabilities,
you can use a “soft,” differentiable threshold:

thresholded_inputs = inputs**alpha / (inputs**alpha + (1 - inputs)**alpha)

(with alpha >= 0.0).

alpha is a parameter that “sharpens” the thresholding: As alpha ranges
from 1.0 to +inf, you interpolate between using the raw, unthresholded
probabilities to fully thresholding them to be exactly either zero or one.
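The soft threshold above can be written as a small helper. The name soft_threshold is hypothetical; the formula is the one given above, with an implicit threshold at 0.5. With alpha = 1 it returns the probabilities unchanged, a large alpha pushes them close to 0 or 1, and gradients still flow for any finite alpha:

```python
import torch


def soft_threshold(probs: torch.Tensor, alpha: float) -> torch.Tensor:
    """Hypothetical helper: differentiable sharpening of probabilities
    toward 0 (below 0.5) or 1 (above 0.5)."""
    return probs**alpha / (probs**alpha + (1 - probs)**alpha)


probs = torch.tensor([0.1, 0.4, 0.6, 0.9], requires_grad=True)

print(soft_threshold(probs, 1.0))   # same as probs
print(soft_threshold(probs, 20.0))  # close to [0, 0, 1, 1]

# Unlike a hard threshold, the sharpened values still pass back gradients.
soft_threshold(probs, 5.0).sum().backward()
print(probs.grad)  # all nonzero
```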

(If you are predicting logits, rather than probabilities, just multiply the
logits by alpha before passing them through sigmoid() to convert
them to probabilities.)
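This works because, for p = sigmoid(z), p / (1 - p) = exp(z), so p**alpha / (p**alpha + (1 - p)**alpha) = 1 / (1 + exp(-alpha * z)) = sigmoid(alpha * z). A numerical check on random logits (illustrative values only):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8)
alpha = 4.0

# Route 1: convert to probabilities first, then apply the soft threshold.
probs = torch.sigmoid(logits)
via_probs = probs**alpha / (probs**alpha + (1 - probs)**alpha)

# Route 2: scale the logits by alpha, then apply sigmoid().
via_logits = torch.sigmoid(alpha * logits)

print(torch.allclose(via_probs, via_logits, atol=1e-5))  # True
```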

(As an aside, using IoU suggests that you are trying to predict which
“pixels” are in certain “objects.” This is a binary segmentation
problem, which is to say a per-pixel binary classification problem. You
might consider using BCEWithLogitsLoss as your loss function, with
pos_weight, if necessary, to compensate for foreground-background
class imbalance. If you feel that using IoU would add value as the loss
function, you might consider augmenting BCEWithLogitsLoss with IoU,
or (a differentiable version of) its cousin, Dice loss.)
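One possible shape for such a combination is sketched below. The names soft_dice_loss and BCEDiceLoss, the dice_weight parameter, and the pos_weight heuristic (negatives / positives) are all illustrative assumptions, not a standard recipe. BCEWithLogitsLoss takes raw logits, so the soft Dice term applies sigmoid() itself:

```python
import torch
import torch.nn as nn


def soft_dice_loss(logits: torch.Tensor, targets: torch.Tensor,
                   smooth: float = 1e-6) -> torch.Tensor:
    """Hypothetical soft Dice loss on logits; batch along dim 0."""
    probs = torch.sigmoid(logits)
    targets = targets.to(probs.dtype)
    dims = tuple(range(1, probs.dim()))  # sum over everything but the batch
    intersect = (probs * targets).sum(dims)
    denom = probs.sum(dims) + targets.sum(dims)
    dice = (2 * intersect + smooth) / (denom + smooth)
    return 1.0 - dice.mean()


class BCEDiceLoss(nn.Module):
    """Hypothetical combination: BCEWithLogitsLoss plus weighted soft Dice."""

    def __init__(self, pos_weight=None, dice_weight: float = 1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.dice_weight = dice_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = targets.to(logits.dtype)
        return self.bce(logits, targets) + self.dice_weight * soft_dice_loss(logits, targets)


# Usage: sparse "peaks" (~10% positives), so weight positives by neg / pos.
torch.manual_seed(0)
targets = (torch.rand(4, 256) < 0.1).long()
pos = targets.sum().clamp(min=1)
pos_weight = ((targets.numel() - pos) / pos).reshape(1).float()

model = nn.Linear(256, 256)
x = torch.randn(4, 256)
criterion = BCEDiceLoss(pos_weight=pos_weight)
loss = criterion(model(x), targets)
loss.backward()
print(loss.item())
```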

Good luck.

K. Frank


Hi K. Frank!
Thank you for your helpful advice, it helps a lot.
I should’ve remembered that the elementwise operation matrix * matrix doesn’t break the chain rule of differentiation :joy:
Following your advice, I rewrote my model and loss function to make them differentiable, and I succeeded in “executing” my training without an error!
This is a huge step forward for me (please allow me to thank you again). However, I’ve run into another problem (it might be easy to fix for a more knowledgeable person), which is:

the weights of the model are not updated over the training epochs.

This is probably a separate issue, so I posted a new topic: “The parameters of the model with custom loss function doesn’t upgraded thorough its learning over epochs” (discuss.pytorch.org).

Thank you!
passiveradio