When I implement Dice loss, the outputs have size Batch x Channel x Width x Height, while the target is a label map of size Batch x Width x Height (one class index per pixel). So I use:
```python
_, o = torch.max(outputs, 1)
```

But when I call backward(), I get this error:

```
RuntimeError: there are no graph nodes that require computing gradients
```
Is that because torch.max is not differentiable? How can I work around it, or do I have to write the backward pass myself?
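A likely explanation, based on how torch.max behaves rather than on anything stated in this thread: with a dim argument, torch.max returns a (values, indices) pair. The values stay differentiable with respect to outputs (the gradient goes to each maximum entry), but the indices (`o` above) are an integer tensor with no gradient history, so a loss built only from them cannot be backpropagated. A minimal check, using a random tensor as a stand-in for the network output:

```python
import torch

# Random stand-in for the network output, shape (B, C, H, W) = (2, 3, 4, 4)
outputs = torch.randn(2, 3, 4, 4, requires_grad=True)

values, indices = torch.max(outputs, 1)

# The max values remain part of the autograd graph...
print(values.requires_grad)    # True
# ...but the argmax indices are integers and carry no gradient
print(indices.requires_grad)   # False
print(indices.dtype)           # torch.int64

# Gradients do flow through the values: each pixel's max entry receives 1
values.sum().backward()
print(outputs.grad.sum(dim=1).unique())  # tensor([1.])
```

So the problem is not torch.max as such but using its indices, which encode a hard, piecewise-constant decision.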
Hi Smth,
Thank you for your reply.
I understand that if torch.max is not differentiable, we cannot write a backward function for it. But is there a workaround that achieves the same idea without using torch.max?
I think I might be able to use torch.nn.Threshold instead.
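For reference, nn.Threshold(threshold, value) replaces every element at or below `threshold` with `value` and passes the other elements through unchanged. It does let gradients through, but only for the elements that survive, and its output is still a score map rather than a one-hot label map, so on its own it does not reproduce what the argmax does. A small sketch with made-up numbers:

```python
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 0.2, 0.7, 1.5], requires_grad=True)

# Keep values above 0.5, replace the rest with 0.0
thresh = nn.Threshold(0.5, 0.0)
y = thresh(x)
print(y.detach())  # surviving values unchanged, the rest set to 0

y.sum().backward()
# Gradient is 1 where x passed the threshold and 0 where it was replaced
print(x.grad)  # tensor([0., 0., 1., 1.])
```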
In fact, I think it's pretty common to use torch.max in a loss function.
Like this: https://github.com/mattmacy/torchbiomed/blob/master/torchbiomed/loss.py#L32
If the input is Batch x ClassNumber x H x W and the target is Batch x H x W, many implementations use torch.max to find the predicted label and compare it with the target.
If you are only doing evaluation, there is no need to call backward.
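For evaluation, the argmax-and-compare pattern is fine because no gradients are needed. A minimal sketch of a "hard" per-class Dice metric under that assumption; `hard_dice` is a hypothetical helper name and the `eps` smoothing term is my addition to avoid division by zero for absent classes:

```python
import torch

def hard_dice(outputs, target, num_classes, eps=1e-6):
    """Per-class Dice from argmax predictions (evaluation only, no gradients).

    outputs: (B, C, H, W) raw scores; target: (B, H, W) integer class labels.
    """
    with torch.no_grad():
        pred = outputs.argmax(dim=1)          # (B, H, W) predicted labels
        scores = []
        for c in range(num_classes):
            p = (pred == c).float()           # binary mask of predicted class c
            g = (target == c).float()         # binary mask of true class c
            inter = (p * g).sum()
            scores.append((2 * inter + eps) / (p.sum() + g.sum() + eps))
        return torch.stack(scores)            # (C,)

torch.manual_seed(0)
outputs = torch.randn(2, 3, 4, 4)
target = outputs.argmax(dim=1)                # a perfect prediction, for illustration
print(hard_dice(outputs, target, num_classes=3))  # all ones
```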
If you are training, you can use the softmax function instead, which is a differentiable approximation to the (one-hot) argmax.
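To illustrate the claim on a single pixel with three made-up class scores: the hard argmax gives a one-hot vector with no gradient, softmax gives a "soft" one-hot that puts most mass on the largest score, and dividing the scores by a small temperature (0.05 here, an arbitrary choice) pushes softmax toward the hard one-hot. Unlike argmax, softmax can be backpropagated through:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)

# Hard argmax as a one-hot vector (built from an index, so no gradient path)
hard = torch.zeros(3)
hard[scores.argmax()] = 1.0
print(hard)  # tensor([0., 1., 0.])

# Softmax: a soft one-hot concentrated on the largest score
soft = F.softmax(scores, dim=0)
print(soft)

# A low temperature sharpens softmax toward the hard one-hot
sharp = F.softmax(scores / 0.05, dim=0)
print(sharp)

# Softmax is differentiable with respect to the scores
soft[1].backward()
print(scores.grad)
```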
Thank you for your reply. But I can't figure out how nn.Softmax can do the job of torch.max.
Could you please give some more details? Thanks so much.
For instance, here's the Dice cost function I have. How can I use softmax to replace the max op?
My suggestion:
First, one-hot encode your target: it is 1 at the correct class index and 0 everywhere else. The target then has shape (B x C x H x W), where B is the batch size, C the number of classes, and H, W the height and width.
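In newer PyTorch releases, one way to do this conversion is torch.nn.functional.one_hot, which appends the class axis last, followed by a permute to move it to dim 1. A minimal sketch with a random label map standing in for the real target:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 4, 4
# Integer label map, one class index per pixel: (B, H, W)
target = torch.randint(0, C, (B, H, W))

# F.one_hot puts the class axis last: (B, H, W, C)
one_hot = F.one_hot(target, num_classes=C)
# Move classes to dim 1 to match the outputs: (B, C, H, W), as float for the loss
one_hot = one_hot.permute(0, 3, 1, 2).float()

print(one_hot.shape)  # torch.Size([2, 3, 4, 4])
```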
Then apply the softmax function to your outputs along the class dimension; the outputs also have shape (B x C x H x W).
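Softmax can also be applied directly along dim=1 of a (B, C, H, W) tensor, without flattening. This sketch, on random data, checks that each pixel's class probabilities sum to 1 and that the result matches the permute-and-flatten route:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
outputs = torch.randn(B, C, H, W)

# Softmax along the class dimension, keeping the (B, C, H, W) layout
probs = F.softmax(outputs, dim=1)

# Same computation after flattening pixels into rows: (B*H*W, C)
flat = outputs.permute(0, 2, 3, 1).reshape(-1, C)
probs_flat = F.softmax(flat, dim=1)

# Each pixel's class probabilities sum to 1
print(torch.allclose(probs.sum(dim=1), torch.ones(B, H, W)))  # True
# Both routes agree once the layouts are matched
print(torch.allclose(probs.permute(0, 2, 3, 1).reshape(-1, C), probs_flat))  # True
```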
Remember to minimise the negative of the resulting Dice score. I am not sure how well this converges, though.
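Putting these steps together, here is a minimal sketch of a soft Dice score, Dice = 2·Σ(p·g) / (Σp + Σg), computed on softmax probabilities p and a one-hot target g and pooled over batch, classes and pixels, then negated for minimisation. `neg_soft_dice` is a hypothetical name and the `eps` term is my addition; since no argmax is involved, backward() works:

```python
import torch
import torch.nn.functional as F

def neg_soft_dice(logits, target_one_hot, eps=1e-6):
    """Negative soft Dice score, pooled over batch, classes and pixels.

    logits:         (B, C, H, W) raw network outputs
    target_one_hot: (B, C, H, W) float one-hot target
    """
    probs = F.softmax(logits, dim=1)            # soft predictions, no argmax
    inter = (probs * target_one_hot).sum()
    denom = probs.sum() + target_one_hot.sum()
    dice = (2 * inter + eps) / (denom + eps)    # between 0 and 1
    return -dice                                # minimise the negative score

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
logits = torch.randn(B, C, H, W, requires_grad=True)
labels = torch.randint(0, C, (B, H, W))
target_one_hot = F.one_hot(labels, C).permute(0, 3, 1, 2).float()

loss = neg_soft_dice(logits, target_one_hot)
loss.backward()   # succeeds: every op in the graph is differentiable
print(loss.item())
```

Returning 1 - dice instead of -dice only shifts the loss by a constant, so the gradients are identical.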
```python
import torch.nn.functional as F

# Replace outputs.max with this
C = outputs.size(1)                                 # number of classes
outputs = outputs.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
outputs = outputs.view(outputs.numel() // C, C)     # (B x H x W, C)
# Probabilities over the C classes for each pixel
outputs = F.softmax(outputs, dim=1)                 # (B x H x W, C)
# targets must already be one-hot encoded, (B, C, H, W)
targets = targets.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
targets = targets.view(targets.numel() // C, C)     # (B x H x W, C)
# The remaining part of the code stays the same as yours.
# Your code weights every class equally, so if most pixels are background,
# the foreground's contribution to the loss may be overshadowed. Consider
# weighting the Dice loss of each class separately: take the sums without
# flattening the class dimension, then add up the per-class terms, each
# multiplied by its own weight.
```
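The per-class weighting suggested in the comments above could look roughly like this: sum over batch and pixels but keep the class axis, compute one soft Dice per class, then combine them with a weight vector. The function name, the weight values and the normalisation of the weights to sum to 1 are all my assumptions, for illustration only:

```python
import torch
import torch.nn.functional as F

def weighted_neg_dice(logits, target_one_hot, class_weights, eps=1e-6):
    """Per-class soft Dice combined with one weight per class (hypothetical helper).

    logits, target_one_hot: (B, C, H, W); class_weights: (C,)
    """
    probs = F.softmax(logits, dim=1)
    dims = (0, 2, 3)                                     # sum over batch and pixels, keep classes
    inter = (probs * target_one_hot).sum(dims)           # (C,)
    denom = probs.sum(dims) + target_one_hot.sum(dims)   # (C,)
    dice_per_class = (2 * inter + eps) / (denom + eps)   # (C,)
    w = class_weights / class_weights.sum()              # normalise weights to sum to 1
    return -(w * dice_per_class).sum()

torch.manual_seed(0)
B, C, H, W = 2, 3, 8, 8
logits = torch.randn(B, C, H, W, requires_grad=True)
labels = torch.randint(0, C, (B, H, W))
target_one_hot = F.one_hot(labels, C).permute(0, 3, 1, 2).float()

# Hypothetical weights: down-weight class 0 (e.g. background)
weights = torch.tensor([0.2, 1.0, 1.0])
loss = weighted_neg_dice(logits, target_one_hot, weights)
loss.backward()
print(loss.item())
```

Because each class gets its own Dice term before weighting, a large background class no longer dominates the score simply by having more pixels.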