I have a custom loss function that uses torch.where. It takes my model's output, out (which has two columns), and does the following:
import torch

out = torch.randn((8, 2), requires_grad=True)  # stand-in for the model's two-column output
loss = torch.where(out[:, 0] > out[:, 1], 1., 0.).mean()
However, no gradients are carried through, since torch.where only selects between the constants 1. and 0., which carry no gradient, and the comparison out[:, 0] > out[:, 1] is itself non-differentiable. Is there any way to implement a loss function like this so that I can call loss.backward()?
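
For reference, here is a minimal reproduction of what I am seeing (the exact error text may vary with the PyTorch version):

import torch

out = torch.randn((8, 2), requires_grad=True)
loss = torch.where(out[:, 0] > out[:, 1], 1., 0.).mean()

print(loss.grad_fn)  # None: the loss is built only from the bool mask and the constants, so it is detached from out
loss.backward()      # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn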