It looks like your issue is due to a troublesome bug in the innards of
autograd – not specific to torch.where(), but in lower-level infrastructure.
However, in your use case, you can work around it by clamping the
denominator of your potential divide-by-zero away from zero. Here
is an illustrative script that contains a modified version of your custom
loss function:
import torch
from torch import nn
import torch.nn.functional as F

print ('torch.__version__', torch.__version__)

torch.manual_seed (2021)

class ConditionalMeanRelativeLoss(nn.Module):
    def __init__(self):
        super(ConditionalMeanRelativeLoss, self).__init__()

    def forward(self, output, target):
        # calculate absolute errors
        absolute_errors = torch.abs(torch.subtract(output, target))
        # where target is too small, use just the absolute errors to avoid divide by 0
        loss = torch.where(torch.abs(target) < 0.001, absolute_errors, torch.abs(torch.divide(absolute_errors, target)))
        print ('pre-mean loss =', loss)
        # return mean loss
        return torch.mean(loss)

class ConditionalMeanRelativeLossB(nn.Module):
    def __init__(self):
        super(ConditionalMeanRelativeLossB, self).__init__()

    def forward(self, output, target):
        # calculate absolute errors
        absolute_errors = torch.abs(torch.subtract(output, target))
        # where target is too small, use just the absolute errors to avoid divide by 0
        # but clamp abs (target) away from zero to avoid "ghost" divide by 0
        abs_target = torch.abs (target).clamp (0.0005)
        loss = torch.where(abs_target < 0.001, absolute_errors, torch.divide(absolute_errors, abs_target))
        print ('pre-mean loss (B) =', loss)
        # return mean loss
        return torch.mean(loss)

outputA = torch.randn (5)
outputB = outputA.clone()
outputA.requires_grad = True
outputB.requires_grad = True

target = torch.randn (5)
target[2] = 0.0
target[3] = 0.0

print ('outputA =', outputA)
print ('outputB =', outputB)
print ('target =', target)

ConditionalMeanRelativeLoss() (outputA, target).backward()
print ('outputA.grad =', outputA.grad)

ConditionalMeanRelativeLossB() (outputB, target).backward()
print ('outputB.grad =', outputB.grad)
As to the autograd bug: a cluster of GitHub issues shows that this is a
known problem. I don’t understand all of the details, but some of the comments
suggest that the bug may be tricky to fix and perhaps won’t get fixed.
Even so, I think that if you understand what is going on, you can work
around it (probably in general, not just in your use case).
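If it helps to see the mechanism in isolation, here is a minimal sketch of
what (as I understand it) goes wrong: backward() still propagates a gradient
into the branch that torch.where() discards, and the chain rule there
multiplies that zero gradient by an infinite derivative, giving nan. (This
stripped-down repro is my own illustration, not code from your model.)

import torch

x = torch.zeros (1, requires_grad = True)

# forward pass is fine -- the condition is True, so where() returns x (0.0)
y = torch.where (x == 0.0, x, 1.0 / x)

# backward() also runs through the discarded 1.0 / x branch; its infinite
# derivative meets the branch's zero-masked gradient, and 0 * inf = nan
y.backward()
print ('x.grad =', x.grad)   # expect tensor([nan])

That is why, in the script above, outputA.grad should come back with nan at
the two positions where target is zero even though the forward loss looks
fine, while the clamp in version B keeps the discarded branch finite, so
outputB.grad stays clean.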