However, when I call loss.backward(), PyTorch raises: ‘RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn’
I understand that torch.eq() will break the gradient, hence question 1 is: how can I circumvent this issue?
I need to count the occurrences of a certain class in the torch tensor, so I'm not sure how I would avoid using torch.eq().
And secondly, does torch_tensor.bool().int() break the gradient as well? If so, would switching to torch.gt(torch_logits, 1).sum() resolve the broken gradient?
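A minimal repro of the issue being described might look like this (the tensor shapes and class index are made up for illustration):

```python
import torch

# Counting occurrences of class 1 via torch.eq: the comparison returns a
# BoolTensor, which carries no grad_fn, so the count is cut off from the
# computation graph and backward() through it fails.
logits = torch.randn(4, 3, requires_grad=True)
preds = logits.argmax(dim=1)        # argmax is also non-differentiable
count = torch.eq(preds, 1).sum()

print(count.requires_grad)  # False
print(count.grad_fn)        # None

# count.float().backward() would raise:
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```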
Yes, torch.eq will return a BoolTensor and will thus detach the output from the computation graph. Unfortunately, I don’t know how to avoid this besides trying to come up with a custom backward function, i.e. you would have to define gradients for the equality/inequality of the values.
Yes, .bool().int() will also break the gradient, as gradients are only defined for floating-point types; torch.gt likewise returns a BoolTensor and would detach the output in the same way.
Generally, you can check whether an operation detached its result by printing the .grad_fn of the inputs and outputs of the function. If .grad_fn is None (and requires_grad is False), the tensor is detached from the graph.
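This check can be sketched like so:

```python
import torch

x = torch.randn(3, requires_grad=True)

y = x * 2            # differentiable op: output carries a grad_fn
z = torch.eq(x, 0.0) # comparison: output is detached

print(y.grad_fn)     # e.g. <MulBackward0 ...>
print(z.grad_fn)     # None -> z is detached from the computation graph
```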
If I were to define custom gradients for eq and not-eq, along with logical_or() and logical_gt(), how would I do that? Can you point me to some example/tutorial for this? I'm fairly new to PyTorch.
I'm thinking of just passing along the gradient if equal, and reducing the gradient to 0 if not equal.
Same with gt: if the value is greater than the compared value, pass along the gradient; otherwise reduce it to 0.
And for logical_or(), I'll add the gradients together, e.g. True + True.
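The pass-through-if-equal idea above could be sketched with torch.autograd.Function, which is the standard way to define a custom backward in PyTorch. Note this is a modeling choice rather than a true derivative (eq has zero gradient almost everywhere), and the class name here is made up:

```python
import torch

class EqWithGrad(torch.autograd.Function):
    """eq with a hand-defined backward: pass the gradient through
    where the inputs were equal, zero it where they were not."""

    @staticmethod
    def forward(ctx, input, other):
        mask = torch.eq(input, other)
        ctx.save_for_backward(mask)
        return mask.float()  # keep a floating dtype so gradients can flow

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # gradient w.r.t. input: grad where equal, 0 elsewhere;
        # no gradient for the comparison value (None)
        return grad_output * mask.float(), None

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
count = EqWithGrad.apply(x, torch.tensor([1.0, 0.0, 3.0])).sum()
count.backward()
print(x.grad)  # tensor([1., 0., 1.])
```

The "Extending PyTorch" tutorial in the official docs covers torch.autograd.Function in detail; gt and logical_or variants would follow the same forward/backward pattern with a different mask or a sum of incoming gradients.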