Custom loss function failure: distance between two points on a globe (Haversine Loss)

Hey, I’m new to PyTorch.

I created a Haversine loss function:

import torch

def haversine_loss(pred, target):
    # pred and target: (batch, 2) tensors of (latitude, longitude) in degrees
    lat1, lon1 = torch.split(pred, 1, dim=1)
    lat2, lon2 = torch.split(target, 1, dim=1)
    r = 6371  # Radius of Earth in kilometers
    phi1, phi2 = torch.deg2rad(lat1), torch.deg2rad(lat2)
    delta_phi, delta_lambda = torch.deg2rad(lat2 - lat1), torch.deg2rad(lon2 - lon1)
    a = torch.sin(delta_phi / 2)**2 + torch.cos(phi1) * torch.cos(phi2) * torch.sin(delta_lambda / 2)**2
    return torch.mean(2 * r * torch.asin(torch.sqrt(a)))
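
For reference, both arguments are expected to be (batch, 2) tensors of (latitude, longitude) in degrees, e.g. (coordinates made up):

pred = torch.tensor([[48.86,   2.35], [40.71, -74.01]])
target = torch.tensor([[51.51,  -0.13], [42.36, -71.06]])
print(haversine_loss(pred, target))  # mean great-circle distance in km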

It works with this simple neural net implemented in PyTorch, mostly inspired by the tutorial.

import torch
from torch import nn

# CREATE MODEL

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(86, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 2)
        )

    def forward(self, x):
        t = self.linear_relu_stack(x)
        return t

model = NeuralNetwork().to(device)
print(model)

# OPTIMIZING THE MODEL PARAMETERS

loss_fn = haversine_loss
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        # check_for_nan_and_save(pred, 'pred', batch)
        # check_for_nan_and_save(y, 'target', batch)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")

torch.autograd.set_detect_anomaly(True)

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

When I run this, at a certain point in the training run NaNs start being produced. It looks something like the following (not copied verbatim from the terminal):


loss: 132.595001 [125664/1370265]
loss: 478.590637 [132064/1370265]
loss: 687.429810 [138464/1370265]
loss: 470.115814 [144864/1370265]
loss: nan [151264/1370265]
loss: nan, etc, etc

I turned on torch.autograd.set_detect_anomaly(True) to see what’s going on.

I get the following:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 7
      5 for t in range(epochs):
      6     print(f"Epoch {t+1}\n-------------------------------")
----> 7     train(train_dataloader, model, loss_fn, optimizer)
      8     test(test_dataloader, model, loss_fn)
      9 print("Done!")

Cell In[11], line 16, in train(dataloader, model, loss_fn, optimizer)
     13 loss = loss_fn(pred, y)
     15 # Backpropagation
---> 16 loss.backward()
     18 max_grad = max(p.grad.abs().max().item() for p in model.parameters() if p.grad is not None)
     19 print(f'Max gradient: {max_grad}')

File ~/.local/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: Function 'SplitBackward0' returned nan values in its 0th output.

This is not an exploding gradient problem, nor are the values being passed into the loss function NaN. When I replace the loss function with torch’s MSE, the errors do not recur, so there must be something wrong with the loss.

Any ideas?
Thanks.

Hello!

I would suggest two things:

  1. Modify your loss function so you can check the input shapes (there could be a shape mismatch that doesn’t interfere with the MSE loss calculation).
  2. Maybe get rid of the split function by using some vectorization, so the torch anomaly detector will point at something else, which may be helpful; a rough sketch is below.
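
Something like this (just a sketch; same math as yours, and it won’t remove the NaNs by itself, but the anomaly trace should then point closer to the sqrt/asin):

import torch

def haversine_loss(pred, target):
    # make the shape expectation explicit: both inputs should be (N, 2) = (lat, lon) in degrees
    assert pred.shape == target.shape and pred.shape[-1] == 2, \
        f"expected (N, 2) inputs, got {pred.shape} and {target.shape}"
    r = 6371  # Earth radius in kilometers
    phi1, lam1 = torch.deg2rad(pred[:, 0]), torch.deg2rad(pred[:, 1])
    phi2, lam2 = torch.deg2rad(target[:, 0]), torch.deg2rad(target[:, 1])
    a = torch.sin((phi2 - phi1) / 2)**2 \
        + torch.cos(phi1) * torch.cos(phi2) * torch.sin((lam2 - lam1) / 2)**2
    return torch.mean(2 * r * torch.asin(torch.sqrt(a)))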

Hope I helped :slight_smile:

Are you normalizing the targets to a range of 0 to 1 or -1 to 1?

Neural networks like normalized values, so that the inputs and outputs fall in a range between 0 and 1 or -1 and 1. This typically resolves a good deal of exploding loss/gradient problems.

For example:

import math

r = 6371  # radius in km

def norm_circ_dist(radius, distance): # use to normalize distance related inputs
    circumference = 2 * radius * math.pi
    return distance / circumference

def actual_circ_dist(radius, rel_dist): # use to reverse the process
    circumference = 2 * radius * math.pi
    return rel_dist * circumference

The same goes for any angular predictions. You want the model inputs and outputs to stay in a normalized range: normalize the inputs, and denormalize the outputs afterwards.
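
For latitude/longitude targets specifically, a minimal sketch of that idea (helper names are just illustrative; assumes degrees in [-90, 90] and [-180, 180]):

import torch

def norm_latlon(latlon_deg):
    # map latitude [-90, 90] and longitude [-180, 180] into [-1, 1]
    scale = torch.tensor([90.0, 180.0], device=latlon_deg.device)
    return latlon_deg / scale

def denorm_latlon(latlon_norm):
    # reverse the process before computing real-world distances
    scale = torch.tensor([90.0, 180.0], device=latlon_norm.device)
    return latlon_norm * scale

The model would then be trained against norm_latlon(y), and its outputs passed through denorm_latlon before computing any haversine distance in kilometers.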

Hi Sam!

While the functions asin() and sqrt() are, themselves, well defined when
their arguments are +-1.0 and 0.0, respectively, they are singular at those
values in that their derivatives become undefined.

As a predicted point (pred) and a target point (target) approach one another,
you approach the sqrt (0.0) singularity. Note that minimizing your loss will drive
your predicted points to become equal to your targets, so this might happen
systematically as you train. Also, as pred and target approach being exactly
opposite one another on the globe (180 degrees apart, for example the north
and south poles of the globe), you approach the asin (+-1.0) singularity.

Yes, as mentioned above, your haversine_loss() contains singularities in its
gradient.

It might work to add a small epsilon to the arguments of asin() and sqrt() to
regulate those singularities. I’ve sketched such an approach, but if you use it
for anything real, you should test it carefully – working around (hacking?) such
numerical issues can be quite delicate and easy to get wrong.

Here is a script that illustrates a modified version of your haversine_loss():

import torch
print (torch.__version__)

def haversine_loss(pred, target, epsSq = 1.e-13, epsAs = 1.e-7):   # add optional epsilons to avoid singularities
    print ('haversine_loss: epsSq:', epsSq, ', epsAs:', epsAs)
    lat1, lon1 = torch.split(pred, 1, dim=1)
    lat2, lon2 = torch.split(target, 1, dim=1)
    r = 6371  # Radius of Earth in kilometers
    phi1, phi2 = torch.deg2rad(lat1), torch.deg2rad(lat2)
    delta_phi, delta_lambda = torch.deg2rad(lat2-lat1), torch.deg2rad(lon2-lon1)
    a = torch.sin(delta_phi/2)**2 + torch.cos(phi1) * torch.cos(phi2) * torch.sin(delta_lambda/2)**2
    # return torch.mean(2 * r * torch.asin(torch.sqrt(a)))
    # "+ (1.0 - a**2) * epsSq" to keep sqrt() away from zero
    # "(1.0 - epsAs) *" to keep asin() away from plus-or-minus one
    return torch.Tensor.mean(2 * r * torch.asin ((1.0 - epsAs) * torch.sqrt (a + (1.0 - a**2) * epsSq)))

targ = torch.tensor ([
    [   0.0,   0.0],   # equator
    [   0.0,   0.0],   # equator
    [   0.0,   0.0],   # equator
    [   0.0, 180.0],   # opposite side of equator
    [   0.0, 180.0],   # opposite side of equator
    [  90.0,   0.0],   # north pole -- degenerate longitude
    [  90.0,   0.0],   # north pole -- degenerate longitude
    [ -90.0,   0.0],   # south pole -- degenerate longitude 
    [ -90.0,   0.0]    # south pole -- degenerate longitude 
])

pred = torch.tensor ([
    [ -15.0,  35.00],   # some point not near a singularity
    [   0.0,   0.00],   # equator -- same point, sqrt (0) error
    [   0.0,   0.05],   # equator -- slightly different point, okay
    [   0.0,   0.00],   # equator -- exactly opposite target, asin (1) error
    [   0.0, 179.90],   # equator -- not quite opposite, okay
    [  90.0,   0.00],   # north pole -- same point, sqrt (0) error
    [  89.9,   0.00],   # not quite north pole, okay
    [  90.0,   0.00],   # north pole -- opposite south pole, asin (1) error
    [  89.9,   0.00]    # not quite north pole -- not quite opposite, okay
], requires_grad = True)

loss = haversine_loss (pred, targ, epsSq = 0.0, epsAs = 0.0)   # turn off epsilons
print ('loss:', loss)

loss.backward()
print ('pred.grad = ...')
print (pred.grad)

pred.grad = None
loss = haversine_loss (pred, targ)                              # use epsilons
print ('loss:', loss)

loss.backward()
print ('pred.grad = ...')
print (pred.grad)

And here is its output:

2.1.0
haversine_loss: epsSq: 0.0 , epsAs: 0.0
loss: tensor(7139.3516, grad_fn=<MeanBackward0>)
pred.grad = ...
tensor([[ -4.2835,  11.1938],
        [     nan,      nan],
        [ -0.0000,  12.3550],
        [     nan,      nan],
        [ -0.0000, -12.3550],
        [     nan,      nan],
        [-12.3550,   0.0000],
        [     nan,      nan],
        [ 12.7488,   0.0000]])
haversine_loss: epsSq: 1e-13 , epsAs: 1e-07
loss: tensor(7137.7847, grad_fn=<MeanBackward0>)
pred.grad = ...
tensor([[-4.2835e+00,  1.1194e+01],
        [-0.0000e+00, -0.0000e+00],
        [-0.0000e+00,  1.2355e+01],
        [-0.0000e+00,  1.1060e-03],
        [-0.0000e+00, -1.2355e+01],
        [ 0.0000e+00, -0.0000e+00],
        [-1.2355e+01,  0.0000e+00],
        [-1.1060e-03, -0.0000e+00],
        [ 1.1041e+01,  0.0000e+00]])

As an aside, although not directly related to your issue, using latitude and
longitude as coordinates on your globe gives you a singular coordinate system
in that longitude becomes degenerate at the north and south poles.
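
For example (values arbitrary), two different (lat, lon) pairs at the north pole describe the same physical point, and the haversine distance between them is essentially zero:

import torch

a = torch.deg2rad(torch.tensor([90.0,   0.0]))   # north pole
b = torch.deg2rad(torch.tensor([90.0, 137.0]))   # north pole again, with an arbitrary longitude

h = torch.sin((b[0] - a[0]) / 2)**2 + torch.cos(a[0]) * torch.cos(b[0]) * torch.sin((b[1] - a[1]) / 2)**2
print(2 * 6371 * torch.asin(torch.sqrt(h)))   # ~0 km (sub-meter, just float round-off)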

Good luck!

K. Frank