I am trying to implement a Dice loss function. My segmentation problem has only two classes, and I want the loss to depend only on the predicted target (foreground) class, not on the background, so I need to find the indices where the model predicts the target class.
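For reference, here is a minimal sketch of a "soft" Dice loss that scores only the foreground class. It assumes `probs` holds per-pixel foreground probabilities (e.g. after a sigmoid) and `target` is a 0/1 mask of the same shape. The name `soft_dice_loss` and the `eps` smoothing term are illustrative, not from any library. Instead of counting hard set sizes in Dice = 2|X∩Y| / (|X| + |Y|), it sums probabilities, so every step stays differentiable:

```python
import torch

def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: 1 - soft Dice over the foreground class only.

    probs:  foreground probabilities in [0, 1] (still attached to the graph)
    target: binary mask (1 = foreground), same shape as probs
    """
    target = target.float()
    # Background pixels (target == 0) add nothing to the intersection;
    # they only enter the denominator through the predicted probabilities.
    intersection = (probs * target).sum()
    denom = probs.sum() + target.sum()
    # eps keeps the ratio defined when both prediction and mask are empty.
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice
```

Because no thresholding happens inside the loss, gradients flow back to the logits. A hard 0.5 threshold is then only needed at evaluation time.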
When I add this line:
```python
input = torch.ge(input, 0.5).float()
```
I get `RuntimeError: there are no graph nodes that require computing gradients`. When I remove this line, everything works. Is it even possible to do this in PyTorch?
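A likely explanation: comparison ops such as `torch.ge` return a boolean tensor with no `grad_fn`, so the result is cut off from the autograd graph. Mathematically, a step function has zero derivative almost everywhere anyway. The sketch below shows this and one common workaround, a straight-through estimator, which uses the hard mask in the forward pass but passes the gradient of the soft probabilities backward. The helper names `hard_threshold` and `straight_through_threshold` are hypothetical, and this is a swapped-in technique, not necessarily the intended fix. Keeping the soft probabilities in the loss is usually the simpler option.

```python
import torch

def hard_threshold(probs: torch.Tensor, thr: float = 0.5) -> torch.Tensor:
    # torch.ge yields a bool tensor; comparisons are not differentiable,
    # so the .float() result has requires_grad=False and grad_fn=None.
    return torch.ge(probs, thr).float()

def straight_through_threshold(probs: torch.Tensor, thr: float = 0.5) -> torch.Tensor:
    hard = torch.ge(probs, thr).float()
    # (probs - probs.detach()) is exactly zero in the forward pass,
    # so the value equals `hard`, but backward sees an identity w.r.t. probs.
    return hard + (probs - probs.detach())

logits = torch.randn(4, 4, requires_grad=True)
probs = torch.sigmoid(logits)
print(hard_threshold(probs).grad_fn)            # None: graph is cut here
st = straight_through_threshold(probs)
st.sum().backward()
print(logits.grad is not None)                  # True: gradient reaches logits
```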