torch.ge triggers "there are no graph nodes that require computing gradients" in a loss function


#1

Hello,

I am trying to implement a Dice loss function, and I need to find the indices where the model predicts the target class. I have a segmentation problem with only 2 classes, and I want my loss to depend only on the predicted class, not on the background.
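For reference, a common "soft" form of the Dice loss for this two-class case is shown here. The symbols are notation introduced for this sketch, not taken from the post: p_i is the predicted foreground probability of pixel i, g_i in {0, 1} is its target label, and epsilon is a small smoothing constant (its value is an assumption and varies between implementations). Because it uses the probabilities p_i directly, every term is differentiable.

```latex
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}
```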

When I add this line:

input = torch.ge(input, 0.5).float()

I get “RuntimeError: there are no graph nodes that require computing gradients”. When I remove this line, everything works. Is it even possible to do this in PyTorch?
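A minimal sketch of why this happens, assuming the model outputs sigmoid probabilities: comparison operators such as torch.ge return a boolean tensor with no grad_fn, so casting it with .float() yields a constant that is detached from the autograd graph. The tensor names are illustrative, and newer PyTorch releases word the error differently from the message in the post.

```python
import torch

torch.manual_seed(0)
# Illustrative stand-in for the model output (assumption: sigmoid probabilities)
logits = torch.randn(4, requires_grad=True)
probs = torch.sigmoid(logits)

# Hard threshold: torch.ge returns a bool tensor without a grad_fn,
# so `hard` is a constant as far as autograd is concerned.
hard = torch.ge(probs, 0.5).float()
print(probs.requires_grad, hard.requires_grad)  # True False

backward_failed = False
try:
    hard.sum().backward()
except RuntimeError as err:
    # The wording of this error depends on the PyTorch version.
    backward_failed = True
    print("backward failed:", err)
```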


(Hugh Perkins) #2

Can you build this out of relu or similar somehow?
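One possible reading of this suggestion, sketched under the assumption that the model outputs probabilities in [0, 1]: shifting by the threshold and applying torch.relu zeroes everything at or below 0.5 while keeping a grad_fn, so gradients reach the entries above the threshold. The input values are made up for illustration.

```python
import torch

# Made-up probabilities standing in for model output
probs = torch.tensor([0.2, 0.4, 0.6, 0.9], requires_grad=True)

# ReLU-based soft mask: zero at or below 0.5, grows linearly above it.
# Unlike torch.ge, relu is part of the autograd graph.
soft_mask = torch.relu(probs - 0.5)
soft_mask.sum().backward()

print(probs.grad)  # tensor([0., 0., 1., 1.])
```

The gradient is 1 for entries above the threshold and 0 for the rest, which matches the goal of letting only the predicted foreground drive the loss.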


#3

Thanks for the advice. I used torch.nn.Threshold, and it seems to work well.
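A minimal sketch of how nn.Threshold might be used inside a Dice loss for this two-class setup. The function name thresholded_dice_loss, the smoothing constant eps=1.0 and the random inputs are hypothetical choices for illustration. nn.Threshold(0.5, 0.0) keeps values above 0.5 unchanged and replaces the rest with 0, so unlike torch.ge it does not binarize the prediction; that is why gradients survive.

```python
import torch
import torch.nn as nn

def thresholded_dice_loss(probs, target, threshold=0.5, eps=1.0):
    # Hypothetical helper: keep probabilities above `threshold`, zero the rest.
    keep_confident = nn.Threshold(threshold, 0.0)
    pred = keep_confident(probs)
    # Soft Dice on the thresholded prediction; eps avoids division by zero.
    intersection = (pred * target).sum()
    dice = (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)
    return 1.0 - dice

torch.manual_seed(0)
# Random stand-ins for a 1-channel 8x8 prediction and its binary mask
logits = torch.randn(1, 1, 8, 8, requires_grad=True)
target = (torch.rand(1, 1, 8, 8) > 0.5).float()

loss = thresholded_dice_loss(torch.sigmoid(logits), target)
loss.backward()  # succeeds: nn.Threshold is differentiable above the threshold
print(loss.item())
```

One trade-off of this design: pixels whose probability is at or below 0.5 receive zero gradient through the threshold, so a missed foreground pixel is not pushed upward by this loss. Feeding the probabilities straight into the soft Dice loss avoids that, at the cost of letting low-confidence predictions influence the loss.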