I have defined a custom loss function and have hit a wall debugging it. It is designed to return a loss that is scaled by the magnitude of the target (a relative error):
import torch
from torch import nn
import torch.nn.functional as F


class ConditionalMeanRelativeLoss(nn.Module):
    def __init__(self):
        super(ConditionalMeanRelativeLoss, self).__init__()

    def forward(self, output, target):
        # calculate absolute errors
        absolute_errors = torch.abs(torch.subtract(output, target))
        # where target is too small, use just the absolute errors to avoid divide by 0
        loss = torch.where(torch.abs(target) < 0.001, absolute_errors, torch.abs(torch.divide(absolute_errors, target)))
        # return mean loss
        return torch.mean(loss)
I was aware that I might create a divide-by-zero error, which is why I used torch.where to try to avoid it. This is the first custom loss function I have ever defined, and when I use it, it returns all nan values. I turned on torch's anomaly detection and saw this error:
/opt/miniconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/__init__.py:145: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
File "HybridMethodConfig1.py", line 322, in <module>
loss = train_model(deriv, derivtrainloader, DE_loss_fn, DE_optim, DEVICE)
File "/Users/henrydikeman/github/CombustTorch/auto_ode/ModelUtilities.py", line 44, in train_model
batch_loss = loss_fn(predictions, batch_results)
File "/opt/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/henrydikeman/github/CombustTorch/auto_ode/CustomLossFunctions.py", line 17, in forward
loss = torch.where(torch.abs(target) < 0.005, absolute_errors, torch.abs(torch.divide(absolute_errors, target)))
(Triggered internally at /Users/distiller/project/conda/conda-bld/pytorch_1614389903258/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward(
0%| | 0/1483 [00:00<?, ?it/s]
Traceback (most recent call last):
File "HybridMethodConfig1.py", line 322, in <module>
loss = train_model(deriv, derivtrainloader, DE_loss_fn, DE_optim, DEVICE)
File "/Users/henrydikeman/github/CombustTorch/auto_ode/ModelUtilities.py", line 47, in train_model
batch_loss.backward()
File "/opt/miniconda3/envs/torch/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/opt/miniconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.
Before this I had been using the built-in MSE loss, so I simply subbed in this function and treated it as a drop-in replacement. I was fairly sure that torch.where is backward-differentiable, but now I am not so sure. You can see I went a bit overboard with explicit torch operations while trying to track down the issue.
I am absolutely sure my code worked before with MSE loss, so unless this function needs to be treated differently than MSE loss, the rest of my code should be fine.
Edit: I tried taking out the line with torch.where and it worked. So I guess I'm asking whether there is any way to get this elementwise conditional logic to work.
It looks like your issue is due to a troublesome bug in the innards of
autograd – not specific to torch.where(), but in lower-level infrastructure.
However, in your use case, you can work around it by clamping the
denominator of your potential divide-by-zero away from zero. Here
is an illustrative script that contains a modified version of your custom
loss function:
import torch
from torch import nn
import torch.nn.functional as F

print('torch.__version__', torch.__version__)

torch.manual_seed(2021)


class ConditionalMeanRelativeLoss(nn.Module):
    def __init__(self):
        super(ConditionalMeanRelativeLoss, self).__init__()

    def forward(self, output, target):
        # calculate absolute errors
        absolute_errors = torch.abs(torch.subtract(output, target))
        # where target is too small, use just the absolute errors to avoid divide by 0
        loss = torch.where(torch.abs(target) < 0.001, absolute_errors, torch.abs(torch.divide(absolute_errors, target)))
        print('pre-mean loss =', loss)
        # return mean loss
        return torch.mean(loss)


class ConditionalMeanRelativeLossB(nn.Module):
    def __init__(self):
        super(ConditionalMeanRelativeLossB, self).__init__()

    def forward(self, output, target):
        # calculate absolute errors
        absolute_errors = torch.abs(torch.subtract(output, target))
        # where target is too small, use just the absolute errors to avoid divide by 0,
        # but clamp abs(target) away from zero to avoid the "ghost" divide by 0
        abs_target = torch.abs(target).clamp(0.0005)
        loss = torch.where(abs_target < 0.001, absolute_errors, torch.divide(absolute_errors, abs_target))
        print('pre-mean loss (B) =', loss)
        # return mean loss
        return torch.mean(loss)


outputA = torch.randn(5)
outputB = outputA.clone()
outputA.requires_grad = True
outputB.requires_grad = True

target = torch.randn(5)
target[2] = 0.0
target[3] = 0.0

print('outputA =', outputA)
print('outputB =', outputB)
print('target =', target)

ConditionalMeanRelativeLoss()(outputA, target).backward()
print('outputA.grad =', outputA.grad)

ConditionalMeanRelativeLossB()(outputB, target).backward()
print('outputB.grad =', outputB.grad)
As for the autograd bug: a cluster of GitHub issues shows that this is a known problem. I don't understand all the details, but some of the comments suggest that it may be tricky to fix and perhaps won't be fixed. But I think that (in general, not just in your use case) if you understand what is going on, you can work around it.