Gradient tracking in custom loss functions

Errors
I am trying to define a loss function for model training where the loss should be zero in many cases. I cannot simply return 0 in these cases, because calling backward() on the result raises an AttributeError. If I instead return torch.Tensor(0), or the value of the condition that decides the zero case, I get a RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
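A minimal sketch reproducing these failures outside the training loop. The helper try_backward is hypothetical and only reports which exception backward() raises. Note that torch.Tensor(0) is not a scalar zero; it builds an empty tensor of shape (0,).

```python
import torch

def try_backward(loss):
    """Call loss.backward() and return the name of the exception raised, or None."""
    try:
        loss.backward()
    except (AttributeError, RuntimeError) as exc:
        return type(exc).__name__
    return None

# A plain Python int has no .backward() method.
print(try_backward(0))                  # AttributeError

# torch.Tensor(0) builds an empty tensor, not a scalar zero.
print(torch.Tensor(0).shape)            # torch.Size([0])

# A freshly created constant tensor has no grad_fn, so autograd rejects it.
print(try_backward(torch.Tensor(0)))    # RuntimeError
print(try_backward(torch.tensor(0.0)))  # RuntimeError
```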

Any advice on how to resolve these errors to get the loss function working would be greatly appreciated!

Objective
The code below takes two discrete distributions with the same support and a weight vector of the same length. It computes the differences between their cumulative sums and checks the binary condition torch.all(diffs >= 0). If that condition holds, it returns a weighted distance between the distributions; if the condition fails, the loss should be zero. Neither returning 0 nor returning torch.all(diffs >= 0) works for the failure case.
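To make the condition concrete, here is a small sketch with hypothetical values chosen to be exact in binary floating point. Since diffs compares the two cumulative sums (the CDFs), the condition holds only when the predicted CDF is at or above the target CDF at every point of the support, in every batch row. The helper name cdf_condition is illustrative.

```python
import torch

def cdf_condition(predictions, target_dists):
    """True when every row's predicted CDF is >= the target CDF at every point."""
    diffs = torch.cumsum(predictions, dim=1) - torch.cumsum(target_dists, dim=1)
    return bool(torch.all(diffs >= 0))

p = torch.tensor([[0.5, 0.25, 0.25]])
q = torch.tensor([[0.25, 0.25, 0.5]])
# CDFs: p -> [0.5, 0.75, 1.0], q -> [0.25, 0.5, 1.0]; diffs -> [0.25, 0.25, 0.0]
print(cdf_condition(p, q))  # True
print(cdf_condition(q, p))  # False
```

Because both CDFs end at 1, the last entry of diffs is zero in exact arithmetic. With general floating-point inputs it can come out slightly negative, so a small tolerance (for example diffs >= -1e-6) may be needed for the condition to hold reliably.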

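import torch

class MyLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predictions: torch.Tensor, target_dists: torch.Tensor, weights: torch.Tensor):
        # Cumulative sums along the support dimension (dim=1), i.e. one CDF per batch row
        pred_sums, target_sums = torch.cumsum(predictions, dim=1), torch.cumsum(target_dists, dim=1)
        diffs = torch.subtract(pred_sums, target_sums)
        if torch.all(diffs >= 0):
            # Weighted distance. torch.dot only accepts 1-D tensors, so a broadcast
            # multiply-and-sum is used for the 2-D (batch, support) inputs implied by dim=1
            return (torch.subtract(predictions, target_dists) * weights).sum()
        else:
            # Returning a plain int here is what makes loss.backward() fail
            return 0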

I suspect the problem is that torch.all prevents gradient tracking, but I’m not sure how to fix it.
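A quick sketch (tensor values are arbitrary) suggesting that torch.all is not what breaks gradient tracking: the condition itself carries no gradient, but it only selects a branch, and the value returned from the taken branch keeps its own autograd history. On this reading, only the else branch, which returns a value with no grad_fn, is the problem.

```python
import torch

x = torch.tensor([0.5, 0.25, 0.25], requires_grad=True)

cond = torch.all(x >= 0)
# The condition is a boolean tensor with no autograd history...
print(cond.requires_grad, cond.grad_fn)  # False None

# ...but it only selects a branch. The tensor returned by the taken branch
# keeps its own graph, so backward() succeeds.
out = (x * 2).sum() if cond else torch.tensor(0.0)
out.backward()
print(x.grad)  # tensor([2., 2., 2.])
```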

Could you simply skip the backward step and the gradient update for iterations/batches where the condition fails, by checking whether the returned loss is 0?
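A minimal sketch of that suggestion, assuming a standard PyTorch training step. The helper train_step is hypothetical; it skips backward() and optimizer.step() whenever the loss function returns the plain int 0. It checks torch.is_tensor(loss) rather than loss == 0, so a genuine tensor loss that happens to equal zero is still backpropagated. The softmax output layer and the broadcast weighted sum in MyLoss are assumptions for illustration.

```python
import torch

class MyLoss(torch.nn.Module):
    """The question's loss, returning the plain int 0 when the condition fails."""
    def forward(self, predictions, target_dists, weights):
        diffs = torch.cumsum(predictions, dim=1) - torch.cumsum(target_dists, dim=1)
        if torch.all(diffs >= 0):
            return ((predictions - target_dists) * weights).sum()
        return 0

def train_step(model, optimizer, loss_fn, inputs, targets, weights):
    """Hypothetical helper: one optimisation step, skipped when the loss is the plain 0.

    Returns True if an update was applied, False if it was skipped.
    """
    optimizer.zero_grad()
    # Assumption: the model outputs logits, turned into distributions with softmax
    predictions = torch.softmax(model(inputs), dim=1)
    loss = loss_fn(predictions, targets, weights)
    if not torch.is_tensor(loss):
        # The loss came back as the untracked int 0: skip backward() and step()
        return False
    loss.backward()
    optimizer.step()
    return True
```

Returning a boolean from the step makes it easy to count how many batches were skipped, which is worth monitoring: if the condition almost never holds, the model receives almost no updates.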


That should work! Thanks. I probably should have thought of that :sweat_smile:
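An alternative that avoids special-casing the training loop (a common PyTorch idiom, not something proposed in this thread) is to return a zero that is still attached to the autograd graph, for example by multiplying the loss by 0.0. backward() then runs normally and produces all-zero gradients. This sketch assumes predictions and target_dists of shape (batch, support) and weights of shape (support,).

```python
import torch

class MyLoss(torch.nn.Module):
    def forward(self, predictions, target_dists, weights):
        diffs = torch.cumsum(predictions, dim=1) - torch.cumsum(target_dists, dim=1)
        # Weighted distance (broadcast multiply-and-sum over the batch)
        loss = ((predictions - target_dists) * weights).sum()
        if torch.all(diffs >= 0):
            return loss
        # Multiplying by 0.0 keeps a grad_fn, so backward() still works and
        # every gradient that flows from this loss is zero
        return loss * 0.0

p = torch.tensor([[0.25, 0.25, 0.5]], requires_grad=True)
q = torch.tensor([[0.5, 0.25, 0.25]])
w = torch.tensor([1.0, 2.0, 3.0])
loss = MyLoss()(p, q, w)  # condition fails, so the loss is zero
loss.backward()           # no error
print(p.grad)             # tensor([[0., 0., 0.]])
```

This is not identical to skipping the step: if the loss is NaN or infinite, multiplying by 0.0 still gives NaN, and optimizers with momentum or weight decay can still move the parameters on a zero-gradient step.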