How to ensure backpropagation (BP) works in a custom loss function

Hello all, thanks in advance for your time.

I am writing a custom loss function that uses 8 methods. My network is for image segmentation and outputs logits.

The custom loss function uses:


However, when I call loss.backward(), I get: ‘RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn’
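The error can be reproduced with a few lines. This is a minimal sketch, assuming a random (N, C, H, W) tensor as a stand-in for the network's logits and class 1 as the class being counted; it is not the original loss function, which is not shown in the thread.

```python
import torch

# Hypothetical stand-in for the segmentation network's output: (N, C, H, W) logits
logits = torch.randn(1, 3, 4, 4, requires_grad=True)

pred = logits.argmax(dim=1)      # LongTensor class map; argmax is not differentiable
mask = torch.eq(pred, 1)         # BoolTensor, no grad_fn
count = mask.sum().float()       # casting back to float does not reattach the graph

loss = count
raised = None
try:
    loss.backward()
except RuntimeError as err:
    raised = err
    print(err)  # element 0 of tensors does not require grad and does not have a grad_fn
```

Once the comparison has produced a bool/int tensor, nothing downstream is connected to logits, so backward() has no graph to traverse.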

I understand that torch.eq() breaks the gradient, so my first question is: how can I circumvent this issue?

I need to count the occurrences of a certain class in the torch tensor, so I'm not sure how I could avoid using torch.eq().
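One commonly used workaround, which is not discussed in this thread, is to replace the hard count (argmax + torch.eq + sum) with a "soft" count: the sum of the softmax probability of that class over all pixels. This is a minimal sketch of that swapped-in technique, assuming (N, C, H, W) logits; the helper name soft_class_count is hypothetical. The soft count only approximates the hard count, but it is differentiable with respect to the logits.

```python
import torch
import torch.nn.functional as F


def soft_class_count(logits: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Hypothetical differentiable surrogate for counting pixels of `class_idx`.

    logits: (N, C, H, W). Returns one float soft count per image, shape (N,).
    """
    probs = F.softmax(logits, dim=1)               # probabilities sum to 1 over C
    return probs[:, class_idx].sum(dim=(1, 2))     # sum over H and W


logits = torch.randn(2, 3, 4, 4, requires_grad=True)
soft = soft_class_count(logits, class_idx=1)       # differentiable approximation
hard = (logits.argmax(dim=1) == 1).sum(dim=(1, 2))  # exact, but no gradient

print(soft, hard)
soft.sum().backward()                              # works: gradients reach logits
print(logits.grad.shape)
```

The closer the softmax is to one-hot, the closer the soft count gets to the hard one; summed over all classes, the soft counts always equal the number of pixels per image.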

Secondly, does torch_tensor.bool().int() break the gradient as well? If so, would switching to, 1).sum() resolve the broken gradient?

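Yes, torch.eq returns a BoolTensor, and since gradients cannot flow through boolean tensors, the output is detached from the computation graph. Unfortunately, I don’t know how to avoid this besides coming up with a custom backward function, i.e. you would have to define gradients for the equality/inequality of the values yourself (the true derivative of a comparison is zero almost everywhere, so any such gradient is a design choice).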

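Yes, since gradients are only defined for floating-point tensors, casting to bool or int detaches the result.
In general, you can check whether an operation detached its result by printing the .grad_fn of the function's inputs and outputs. If a computed output shows .grad_fn=None, it is detached (a leaf tensor you created yourself with requires_grad=True also has no .grad_fn, so check .requires_grad as well).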
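A quick way to run that check is to print dtype, requires_grad and grad_fn for each intermediate result. This sketch assumes a small random input x in place of real logits.

```python
import torch

x = torch.randn(5, requires_grad=True)   # leaf tensor: grad_fn is None, requires_grad is True

checks = {
    "x * 2": x * 2,                        # float op, stays in the graph
    "torch.eq(x, 0)": torch.eq(x, 0),      # bool result, detached
    "x.bool().int()": x.bool().int(),      # integer result, detached
    "(x > 0).float()": (x > 0).float(),    # float again, but the comparison already cut the graph
}

for name, t in checks.items():
    print(f"{name:18} dtype={t.dtype}  requires_grad={t.requires_grad}  grad_fn={t.grad_fn}")
```

Only "x * 2" should report a grad_fn (MulBackward0); the other three are cut off from x, which is why casting back to float after a comparison does not restore the gradient.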

If I were to define custom gradients for eq and ne (not equal), along with logical_or() and gt(), how would I do that? Can you point me to an example or tutorial? I'm fairly new to PyTorch.

I am thinking of simply passing the gradient through where the values are equal, and setting it to 0 where they are not.
The same for gt: if the value is greater than the compared value, pass the gradient through; otherwise set it to 0.
For logical_or(), I’ll add the gradients together, e.g. True + True.
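The gradient rules proposed above can be written as custom torch.autograd.Function subclasses. This is a sketch under stated assumptions, not an established API: the class names are hypothetical, the second argument of the comparisons is treated as a constant (it gets no gradient), and both inputs to the OR are assumed to have the same shape. These are straight-through-style surrogate gradients, since the true derivative of a comparison is zero almost everywhere. The forward passes return float masks instead of bools so the outputs can stay in the graph.

```python
import torch


class EqPassThrough(torch.autograd.Function):
    """Hypothetical eq: forward gives a float mask, backward passes grad where x == y."""

    @staticmethod
    def forward(ctx, x, y):
        mask = torch.eq(x, y)
        ctx.save_for_backward(mask)
        return mask.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Pass the gradient through where equal, zero elsewhere; y is treated as a constant.
        return grad_output * mask.to(grad_output.dtype), None


class GtPassThrough(torch.autograd.Function):
    """Hypothetical gt: backward passes grad where x > y, zero elsewhere."""

    @staticmethod
    def forward(ctx, x, y):
        mask = torch.gt(x, y)
        ctx.save_for_backward(mask)
        return mask.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask.to(grad_output.dtype), None


class OrPassThrough(torch.autograd.Function):
    """Hypothetical logical_or on float masks (assumes a and b have the same shape)."""

    @staticmethod
    def forward(ctx, a, b):
        return torch.logical_or(a.bool(), b.bool()).to(a.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Each input receives grad_output unchanged, like the derivative of a + b;
        # if both inputs come from the same tensor, autograd adds the two contributions.
        return grad_output, grad_output


x = torch.tensor([0.0, 1.0, 2.0, 1.0], requires_grad=True)
is_one = EqPassThrough.apply(x, 1.0)       # [0., 1., 0., 1.]
is_big = GtPassThrough.apply(x, 1.5)       # [0., 0., 1., 0.]
either = OrPassThrough.apply(is_one, is_big)
count = either.sum()                       # 3.0
count.backward()
print(count.item(), x.grad)                # 3.0 tensor([0., 1., 1., 1.])
```

With these rules, every element that contributes to the count receives a gradient of 1 and every other element receives 0.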


Sure, this tutorial shows how to implement custom autograd.Functions.