Torch.nonzero breaks gradient?

I’m defining a custom loss that finds FFT peaks (I haven’t found anything similar so far).
I found a peak-picking method online that works well, but when I use it in training I run into a problem: grad_fn is None on some tensor.
(this is the method from stackoverflow: python - Pytorch Argrelmax function (or C++) - Stack Overflow)

I have been debugging the loss and found that the gradient breaks at this line:
peaksHgen = torch.nonzero(hgpk2, out=None) + 1

What I see is that hgpk2 has its grad_fn correctly set to WhereBackward0, while peaksHgen has grad_fn=None.

Why is that? Doesn’t nonzero have a gradient function?
I’m not posting my code here because it is identical to the one on Stack Overflow apart from variable names.

Hi @leopard86,

I’m pretty sure torch.nonzero doesn’t have a gradient function, because what would the gradient of a function that returns the indices of the non-zero entries even be? The output is a tensor of integer positions, and as a function of the input values it is piecewise constant: nudging a non-zero value slightly doesn’t change which indices come back, so the derivative is zero wherever it is defined and undefined where an entry crosses zero. Since the output is an integer index tensor, autograd treats torch.nonzero as non-differentiable, and its result is detached from the graph, which is why peaksHgen has grad_fn=None.