Custom loss function error

Hello everyone, I am currently trying to implement a custom dice loss function for my semantic segmentation model.
The problem is that I keep getting the following error:
'RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn'
I know that this error is caused by using torch.argmax(pred, dim=1), since the argmax function is not differentiable. However, I cannot think of another way to do this without argmax, because my dice loss function needs to compare pred and target values like this:
with torch.enable_grad():
    num = 2 * torch.sum(pred * target)
    den = torch.sum(pred + target)
    return 1 - torch.true_divide((num + 1), (den + 1))
Could somebody please help me with this problem?
Thank you!

My full code looks like this:
import torch

def custom_loss(output, labels):
    preds = torch.nn.LogSoftmax(dim=1)(output)
    # argmax returns integer class indices with no grad_fn, cutting the autograd graph
    preds = torch.argmax(preds, dim=1)
    print(preds)
    loss = dice_loss(preds, labels)
    return loss

def dice_loss(pred, target):
    with torch.enable_grad():
        num = 2 * torch.sum(pred * target)
        den = torch.sum(pred + target)
        return 1 - torch.true_divide((num + 1), (den + 1))
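A small reproduction may help show where the graph breaks. This sketch assumes hypothetical random logits of shape (N, 2, H, W) and random 0/1 masks; it shows that argmax produces an int64 tensor without a grad_fn, so the loss built from it cannot be backpropagated. Note that torch.enable_grad() does not help here, because no input to dice_loss requires grad.

```python
import torch

def dice_loss(pred, target):
    # Same dice loss as above; enable_grad() cannot attach a grad_fn to
    # inputs that were already cut off from the autograd graph.
    with torch.enable_grad():
        num = 2 * torch.sum(pred * target)
        den = torch.sum(pred + target)
        return 1 - torch.true_divide(num + 1, den + 1)

torch.manual_seed(0)
output = torch.randn(2, 2, 4, 4, requires_grad=True)  # hypothetical logits (N, C, H, W)
labels = torch.randint(0, 2, (2, 4, 4)).float()        # hypothetical 0/1 masks (N, H, W)

# argmax yields integer class indices: no gradient can flow through them
preds = torch.argmax(torch.log_softmax(output, dim=1), dim=1)
print(preds.dtype, preds.requires_grad, preds.grad_fn)

loss = dice_loss(preds, labels)
raised = False
try:
    loss.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```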

Don't use argmax; use softmax and take one component of the result (the probability of class 1).
This gives you a "soft dice" loss, which is actually differentiable (argmax does not have meaningful gradients).
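As a rough illustration of that suggestion, here is a minimal soft dice sketch for binary segmentation. It assumes logits of shape (N, 2, H, W) and integer 0/1 targets of shape (N, H, W); the function name soft_dice_loss and the smooth parameter are hypothetical, with smooth=1.0 mirroring the +1 terms in the original code.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target, smooth=1.0):
    """Soft dice loss for binary segmentation (sketch).

    logits: (N, 2, H, W) raw scores; target: (N, H, W) with values 0/1.
    """
    # Probability of class 1 per pixel, used instead of argmax's hard index
    prob = F.softmax(logits, dim=1)[:, 1]
    target = target.float()
    num = 2 * torch.sum(prob * target)
    den = torch.sum(prob) + torch.sum(target)
    return 1 - (num + smooth) / (den + smooth)

torch.manual_seed(0)
logits = torch.randn(2, 2, 4, 4, requires_grad=True)  # hypothetical model output
labels = torch.randint(0, 2, (2, 4, 4))                 # hypothetical 0/1 masks

loss = soft_dice_loss(logits, labels)
loss.backward()  # works: softmax keeps the autograd graph intact
print(loss.item(), logits.grad.abs().sum().item())
```

Because every step is a differentiable tensor operation, the loss keeps its grad_fn and gradients reach the logits.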

Hi, thank you for your reply!
I am not sure I fully understand your answer.
When I apply nn.Softmax(), I get vectors containing two different probabilities, since it is a binary segmentation model.
Softmax alone does not pick the index of the highest probability to compare with my target values, so surely I still need something like argmax after the softmax?

I don't think I can explain it much better than Wikipedia's article on softmax, but
torch.nn.functional.softmax(score, dim=1)[:, 1] is actually a smooth approximation of torch.argmax(score, dim=1). That is just what you need, and it is the reason almost any classifier can be trained through softmax and still be used with argmax at prediction time. We write a bit about that in section 7.2 of our book, which can currently be downloaded for free (sections 7.2.3-7.2.5 in particular). We also use the dice loss in chapter 13 of the book.
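One way to see why this works for two classes: softmax(score)[:, 1] equals sigmoid(score[:, 1] - score[:, 0]), so it crosses 0.5 exactly where argmax switches from 0 to 1. The sketch below checks this on a few hypothetical logits (one row per pixel, two classes).

```python
import torch
import torch.nn.functional as F

score = torch.tensor([[2.0, -1.0],
                      [0.3, 0.8],
                      [-0.5, 1.5],
                      [1.2, 1.0]])  # hypothetical logits: 4 pixels x 2 classes

soft = F.softmax(score, dim=1)[:, 1]   # smooth "probability of class 1"
hard = torch.argmax(score, dim=1)      # hard class index (0 or 1)

# For two classes, softmax column 1 is a sigmoid of the logit difference
via_sigmoid = torch.sigmoid(score[:, 1] - score[:, 0])

print(soft)
print(via_sigmoid)
print(hard, (soft > 0.5).long())  # thresholding the soft value reproduces argmax
```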

Best regards


P.S.: If you enclose your code in lines with triple backticks (```), you get source code formatting.

Thank you so much for your reply, I will have a go at reading your book!
Just one more question: you said torch.nn.functional.softmax(score, dim=1)[:, 1] is the same as torch.argmax(score, dim=1), but wouldn't this always select the element at index 1, regardless of the probabilities of the two elements?
Here is an example of what I got when I tried your method:
normal softmax: tensor([[0.0850, 0.9150],
        [0.0598, 0.9402],
        [0.0765, 0.9235],
        [0.6879, 0.3121],
        [0.1176, 0.8824],
        [0.6816, 0.3184]], device='cuda:0', grad_fn=)

After: tensor([0.9150, 0.9402, 0.9235, …, 0.3121, 0.8824, 0.3184], device='cuda:0',

The tensor after 'After:' is the result of your recommended line (torch.nn.functional.softmax(score, dim=1)[:, 1]). As you can see, it always returns the probability of the element at index 1, not an index, regardless of how the probabilities of the first and second elements compare.

To form an intuition, make your values more extreme:

scores = torch.randn(5, 2) * 10
print(scores.softmax(dim=1)[:, 1])

This is what I mean by softmax being an approximation of argmax. Of course, taking one column only works for 0 and 1 (two classes); in general, softmax gives you an approximation of the one-hot encoding of argmax.
If you take away the * 10, you get a similar effect (argmax corresponds to thresholding at 0.5), but the approximation is less close in general. That is actually good, because the gradients are larger when the values are not so close to 0 or 1.
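To make the trade-off concrete, this sketch scales the same hypothetical random logits by 1, 10, and 100 and measures two things: how far softmax(...)[:, 1] is from the hard argmax, and the mean gradient magnitude flowing back to the scores. Float64 is used only to avoid rounding the tiny values to zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
base = torch.randn(8, 2, dtype=torch.float64)  # hypothetical logits for 8 pixels
hard = torch.argmax(base, dim=1).double()      # positive scaling does not change argmax

gaps, grads = [], []
for scale in (1.0, 10.0, 100.0):
    scores = (base * scale).requires_grad_()
    soft = F.softmax(scores, dim=1)[:, 1]
    gap = (soft - hard).abs().max().item()     # worst-case distance from argmax
    soft.sum().backward()
    grad = scores.grad.abs().mean().item()     # typical gradient size
    gaps.append(gap)
    grads.append(grad)
    print(f"scale={scale:>5}: max |soft - argmax| = {gap:.6f}, mean |grad| = {grad:.6f}")
```

Both numbers shrink as the scale grows: the soft value hugs the argmax more tightly, but the gradients vanish, which is why a less extreme approximation trains better.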

Oh, I see now why you said it's an approximation.
But to compare with my target values, I (think I) need values that are either 0 or 1. So would rounding after scores.softmax(dim=1)[:, 1] be a correct way of doing it?
For example,

pred = scores.softmax(dim=1)[:, 1]
pred = torch.round(pred)

def dice_loss(pred,target):
    with torch.enable_grad():
        numerator = 2 * torch.sum(pred * target)
        denominator = torch.sum(pred + target)
        return 1 - torch.true_divide((numerator + 1), (denominator + 1))
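A quick check of what rounding does to training, sketched with hypothetical random logits and targets: torch.round keeps a grad_fn, so no error is raised, but its gradient is zero everywhere, so the logits receive no learning signal. The helper name grad_total is hypothetical.

```python
import torch

def dice_loss(pred, target, smooth=1.0):
    numerator = 2 * torch.sum(pred * target)
    denominator = torch.sum(pred + target)
    return 1 - (numerator + smooth) / (denominator + smooth)

torch.manual_seed(1)
labels = torch.randint(0, 2, (16,)).float()  # hypothetical 0/1 targets

def grad_total(use_round):
    torch.manual_seed(0)  # same hypothetical logits for both runs
    scores = torch.randn(16, 2, requires_grad=True)
    pred = scores.softmax(dim=1)[:, 1]
    if use_round:
        pred = torch.round(pred)  # hard 0/1, but d(round)/dx is zero
    dice_loss(pred, labels).backward()
    return scores.grad.abs().sum().item()

rounded_grad = grad_total(True)
soft_grad = grad_total(False)
print(f"with round: {rounded_grad}, without round: {soft_grad:.4f}")
```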

Well, that's what they tell you in the definition of dice loss, but you don't actually have to use it that way. If you want an intuition, think of each value as the probability of a 1 in a Bernoulli distribution. Then the dice loss you get with probabilities is the one you would obtain (in the limit) if you sampled many hard 0/1 copies and computed the dice loss over all of them in one go. Note that this is not the expected dice loss but the dice loss of the expectation.
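The Bernoulli intuition can be checked numerically. This sketch uses hypothetical per-pixel probabilities and targets, drops the +1 smoothing (which does not scale with the number of copies), samples many hard copies, and pools them into one dice computation; the pooled value should land close to the soft dice. The per-copy average is printed only for contrast.

```python
import torch

torch.manual_seed(0)
p = torch.rand(50)                      # hypothetical probabilities of class 1
t = (torch.rand(50) < 0.5).float()      # hypothetical 0/1 targets

def dice(pred, target):
    # Plain dice coefficient, without the +1 smoothing term
    return 2 * torch.sum(pred * target) / (torch.sum(pred) + torch.sum(target))

soft = dice(p, t)

# Draw many hard 0/1 copies and compute ONE dice over the whole lot
copies = 10_000
samples = torch.bernoulli(p.repeat(copies, 1))      # (copies, 50)
pooled = dice(samples, t.expand(copies, -1))

# For contrast: the average of per-copy dice scores (the "expected dice")
per_copy = (2 * (samples * t).sum(1) / (samples.sum(1) + t.sum())).mean()

print(f"soft: {soft.item():.4f}  pooled: {pooled.item():.4f}  per-copy mean: {per_copy.item():.4f}")
```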

Thank you so much for the help!

Hi, just another question.
The dice loss you implemented in the book is for the binary case.
How can I construct a dice loss for multiple classes?

Typically by computing a one-vs-all dice score for each class and then averaging them, either weighted or unweighted.
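A minimal sketch of that recipe, assuming logits of shape (N, C, H, W) and integer labels of shape (N, H, W): each class gets its own soft dice against a one-hot mask, and the per-class scores are averaged, optionally with per-class weights. The function name and the weights and smooth parameters are hypothetical choices, not an established API.

```python
import torch
import torch.nn.functional as F

def multiclass_soft_dice_loss(logits, target, weights=None, smooth=1.0):
    """One-vs-all soft dice, averaged over classes (sketch).

    logits: (N, C, H, W) raw scores; target: (N, H, W) int64 class labels.
    weights: optional (C,) per-class weights; None means an unweighted mean.
    """
    num_classes = logits.shape[1]
    prob = F.softmax(logits, dim=1)                                       # (N, C, H, W)
    onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()  # (N, C, H, W)

    dims = (0, 2, 3)  # sum over batch and pixels, keep the class axis
    num = 2 * torch.sum(prob * onehot, dims)
    den = torch.sum(prob, dims) + torch.sum(onehot, dims)
    dice_per_class = (num + smooth) / (den + smooth)                     # (C,)

    if weights is None:
        return 1 - dice_per_class.mean()
    weights = weights / weights.sum()  # normalize so weights sum to 1
    return 1 - torch.sum(weights * dice_per_class)

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)  # hypothetical 3-class output
labels = torch.randint(0, 3, (2, 4, 4))
loss = multiclass_soft_dice_loss(logits, labels)
loss.backward()
print(loss.item())
```

Weighting (for example, by inverse class frequency) is one common way to keep small classes from being drowned out; equal weights reduce to the plain mean.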