I am trying to define a loss function for model training where the loss should be zero in many cases. I cannot simply return 0 in these cases because that raises an AttributeError, and if I instead return torch.Tensor(0) (or the boolean condition itself), I get a RuntimeError that says

element 0 of tensors does not require grad and does not have a grad_fn
Any advice on how to resolve these errors to get the loss function working would be greatly appreciated!
The code below takes two discrete distributions with the same support and a weight vector of the same length as the distributions. It checks a binary condition, torch.all(diffs >= 0): if that condition holds, it returns a kind of weighted distance between the distributions; if the condition fails, the loss should be zero. Neither returning 0 nor returning torch.all(diffs >= 0) works for the failure case.
class MyLoss(torch.nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, predictions: torch.Tensor, target_dists: torch.Tensor, weights: torch.Tensor):
        pred_sums = torch.cumsum(predictions, dim=1)
        target_sums = torch.cumsum(target_dists, dim=1)
        diffs = torch.subtract(pred_sums, target_sums)
        if torch.all(diffs >= 0):
            return torch.dot(torch.subtract(predictions, target_dists), weights)
        else:
            return 0
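To show what I mean, here is a minimal standalone reproduction (not my actual training loop, just dummy tensors) of the RuntimeError: a zero built as a plain constant has no grad_fn, so calling backward() on it fails.

```python
import torch

pred = torch.rand(2, 3, requires_grad=True)

# A zero created as a constant is disconnected from the computation graph:
loss = torch.tensor(0.0)
print(loss.requires_grad, loss.grad_fn)  # False None

try:
    loss.backward()
except RuntimeError as e:
    # element 0 of tensors does not require grad and does not have a grad_fn
    print(e)
```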
I suspect the problem is that both 0 and the result of torch.all are constants with no grad_fn, so autograd has nothing to backpropagate through, but I'm not sure how to fix it.
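One workaround I have been considering (I'm not sure it is the right approach) is to build the zero from the inputs themselves, so it stays connected to the graph and backward() just produces zero gradients:

```python
import torch

pred = torch.rand(2, 3, requires_grad=True)

# Zero that is still part of the computation graph:
loss = (pred * 0.0).sum()
print(loss.requires_grad)  # True

loss.backward()  # no RuntimeError; pred.grad is all zeros
print(torch.all(pred.grad == 0).item())  # True
```

Is this a reasonable way to return a "zero loss", or is there a more idiomatic fix?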