Getting "no graph nodes that require computing gradients" when using torch.max?

Hi,

When I implement Dice loss, the size of outputs is Batch x Channel x Width x Height while the target is 1-hot encoded as Batch x Width x Height. So I use

_,o = torch.max(outputs,1)

But when I call backward(), I get this error:

RuntimeError: there are no graph nodes that require computing gradients

Is that because torch.max is not differentiable? How can I work around it? Or do I have to write the backpropagation part myself?

Thanks

The indices returned by torch.max (the argmax) are not differentiable. How would you write the back-propagation yourself if the operation is not differentiable?
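To make the distinction concrete, here is a small sketch, assuming a recent PyTorch version (the tensor shapes and variable names are made up for illustration). The values returned by torch.max stay in the autograd graph, while the indices are an integer tensor with no gradient, which is why a loss built only from the indices cannot be backpropagated.

```python
import torch

# Illustrative shapes only: B x C x H x W
outputs = torch.randn(2, 3, 4, 4, requires_grad=True)
values, indices = torch.max(outputs, dim=1)

print(values.requires_grad)   # the max values are tracked by autograd
print(indices.requires_grad)  # the integer indices are not

# Backward through the values works: the gradient lands on the winning channel.
values.sum().backward()
winners_per_pixel = (outputs.grad != 0).sum(dim=1)

# Backward through the indices fails, even after casting to float.
backward_failed = False
try:
    indices.float().sum().backward()
except RuntimeError:
    backward_failed = True
print(backward_failed)
```

The exact error text differs between PyTorch versions, so the sketch only checks that a RuntimeError is raised.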

Hi Smth,
Thank you for your reply.
I know that if torch.max is not differentiable, we cannot write the backward function. But I’m wondering whether there is a workaround that does not use torch.max but achieves the same idea.

I think I could maybe use torch.nn.Threshold instead.

In fact, I think it’s pretty common to use torch.max in a loss function.
Like this: https://github.com/mattmacy/torchbiomed/blob/master/torchbiomed/loss.py#L32
If the input size is Batch x ClassNumber x H x W while the target size is Batch x H x W, many applications would use torch.max to find the right label and compare with the target.

Qi

If you are doing an evaluation, there is no need to call backward.
If you are training, you can use the Softmax function, which is a differentiable approximation to the argmax (the one-hot indicator of the max).
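A rough sketch of why softmax can stand in for the max here, in plain Python. The `temperature` parameter is an assumption added for illustration and is not an argument of nn.Softmax: as it shrinks, the softmax output approaches the one-hot vector that a hard argmax would give, while staying smooth in the scores.

```python
import math


def softmax(scores, temperature=1.0):
    """Numerically stable softmax over a list of floats."""
    scaled = [s / temperature for s in scores]
    shift = max(scaled)  # subtract the max to avoid overflow in exp
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 3.0, 2.0]

# Hard argmax written as a one-hot vector: piecewise constant in the scores.
winner = scores.index(max(scores))
hard = [1.0 if i == winner else 0.0 for i in range(len(scores))]

soft = softmax(scores)                      # smooth probabilities
sharp = softmax(scores, temperature=0.05)   # nearly one-hot

print(hard)
print(soft)
print(sharp)
```

The largest probability always sits at the same position as the hard argmax, so the soft version can be multiplied with a one-hot target in the same way the hard labels were.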

Hi ImgPreSng,

Thank you for your reply. But I could not figure out how nn.Softmax could take the place of torch.max.
Could you please give some more details? Thanks so much.

For instance, here’s the Dice cost function I have. How can I use softmax to replace the max op?

def forward(self, outputs, targets):
    smooth = 1e-5
    # o holds integer class indices (argmax over channels); no gradient flows through it
    _, o = outputs.max(1)
    iflat = o.view(-1).float()
    tflat = targets.view(-1).float()
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

My suggestion:
First, encode your target as a one-hot encoding, where the target is 1 at the correct class index and 0 elsewhere. The target will then have dimension (B x C x H x W), where B is the batch size, C the number of classes, and H, W the height and width.

Then you can apply the softmax function to your outputs, which are also of dimension (B x C x H x W).

Remember to use the negative of this score when minimising. I am not sure about the convergence of this, though.

# Replace outputs.max with this
C = outputs.size(1)  # number of classes
outputs = outputs.permute(0, 2, 3, 1).contiguous()
outputs = outputs.view(outputs.numel() // C, C)  # (B x H x W, C)
outputs = torch.nn.Softmax(outputs) #( B x H x W, C), 
# Probabilities over C classes for each Pixel

targets = targets.permute(0,2,3,1).contiguous()
targets = targets.view(targets.numel() // C, C) # ( B x H x W, C)

# The remaining part of the code will be the same as yours.
# Your code considers each class with equal weight, so if you have too much background,
# the loss due to the foreground might be overshadowed. Consider computing the Dice loss
# of each class separately (summing without flattening across classes) and then
# summing over the classes with a weight attached to each.

You should also take a look at NLLLoss2d for segmentation problems: http://pytorch.org/docs/master/nn.html#torch.nn.NLLLoss2d
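Putting the suggestion together, here is a minimal sketch of a soft Dice loss, not a definitive implementation. It assumes a recent PyTorch, targets given as integer class indices of shape (B x H x W), and uses F.softmax and F.one_hot. The class name `SoftDiceLoss` and the optional `class_weights` argument are hypothetical names introduced for illustration; they do not come from this thread or from PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftDiceLoss(nn.Module):
    """Hypothetical soft Dice loss: softmax probabilities replace the argmax."""

    def __init__(self, class_weights=None, smooth=1e-5):
        super().__init__()
        self.class_weights = class_weights  # optional tensor of shape (C,)
        self.smooth = smooth

    def forward(self, outputs, targets):
        # outputs: (B, C, H, W) raw scores; targets: (B, H, W) int64 class indices
        num_classes = outputs.size(1)
        probs = F.softmax(outputs, dim=1)
        one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

        # Sum over batch and spatial dims, keeping one Dice score per class.
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        denominator = probs.sum(dims) + one_hot.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)

        if self.class_weights is None:
            return 1.0 - dice.mean()
        weights = self.class_weights / self.class_weights.sum()
        return 1.0 - (weights * dice).sum()


# Usage with made-up shapes: 2 images, 3 classes, 8 x 8 pixels.
outputs = torch.randn(2, 3, 8, 8, requires_grad=True)
targets = torch.randint(0, 3, (2, 8, 8))
loss = SoftDiceLoss()(outputs, targets)
loss.backward()  # works, because every step above is differentiable
```

Keeping one Dice score per class, instead of flattening everything into a single vector, is what makes per-class weighting possible for unbalanced data.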

Ahhh, got it. Thanks so much.
Just a minor error here I think:

outputs = torch.nn.Softmax(dim=1)(outputs)  # instantiate the module, then call it; dim=1 is the class axis

And you are right, I need to add class weights for such unbalanced data.