Gradients for custom loss function using torch.where

I have a custom loss function that uses torch.where. It takes my model's output, out (which has two columns), and does the following:

import torch

out = torch.randn((8, 2), requires_grad=True)
loss = torch.where(out[:, 0] > out[:, 1], 1., 0.).mean()  # 1. where col 0 > col 1, else 0.

However, no gradients flow through this loss, since torch.where selects between the constants 1. and 0., which carry no gradient information. Is there any way to implement a loss function like this so that I can use loss.backward()?
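For concreteness, continuing the snippet above (the exact error text may vary across PyTorch versions):

print(loss.requires_grad)  # False: the graph ends at the constant branch values
loss.backward()            # RuntimeError: element 0 of tensors does not require grad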

Technically, the following may work:

d = out[:, 0] - out[:, 1]
# |d| * H(d) equals d where d > 0 and 0 elsewhere; gradients flow through d.abs() only
loss = (d.abs() * torch.heaviside(d, torch.zeros(1))).mean()

but it may still be difficult to optimize with this formulation, as the gradients are discrete, taking values in {-1, 0, 1}. Adam plus a learning-rate scheduler may get you somewhere; if not, reformulate your loss with d**2 and constraints/penalties somehow.
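If a fully smooth surrogate is acceptable, one further option (a sketch, not part of the answer above; the temperature tau is an illustrative choice) is to relax the hard 0/1 indicator into a sigmoid of the difference:

import torch

out = torch.randn((8, 2), requires_grad=True)
d = out[:, 0] - out[:, 1]

tau = 0.1  # smaller tau -> sharper step, but gradients can vanish where |d| / tau is large
loss = torch.sigmoid(d / tau).mean()  # smooth estimate of the fraction of rows with col 0 > col 1

loss.backward()
print(out.grad)  # gradients now flow through the loss

As tau -> 0 this approaches the original torch.where loss, at the cost of increasingly flat gradients away from the decision boundary.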
