NaN loss appearing after some time

I’m also encountering a similar problem with my model. After a few iterations of training on graph data, the loss (MSELoss between the returned output and a fixed label) becomes NaN.

Model:

```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class Model(nn.Module):
    def __init__(self, nin, nhid1, nout, inp_l, hid_l, out_l=1):
        super().__init__()
        # Two graph convolutions: [num_nodes, nin] -> [num_nodes, nhid1] -> [num_nodes, nout]
        self.g1 = GCNConv(in_channels=nin, out_channels=nhid1)
        self.g2 = GCNConv(in_channels=nhid1, out_channels=nout)
        self.dropout = 0.5
        # Per-node MLP head; inp_l must equal nout
        self.lay1 = nn.Linear(inp_l, hid_l)
        self.lay2 = nn.Linear(hid_l, out_l)

    def forward(self, x, edge_index):
        x = F.relu(self.g1(x, edge_index))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.g2(x, edge_index)

        x = self.lay1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = F.relu(x)  # final ReLU: every negative output is clamped to exactly 0

        return x
```

The inputs to the model:

x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features].
edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges].

Here num_nodes = 1000, num_node_features = 1, and num_edges = 5000.

[GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) is a graph convolution layer that returns a [num_nodes, out_channels] matrix.


I would recommend checking the inputs and targets for invalid values first.
If there are no invalid values, you could observe the loss and check whether it’s blowing up (if so, reduce the learning rate). If that’s also not the case, you should check all intermediate activations, e.g. via forward hooks, to track down the operation that creates the NaNs.
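A minimal sketch of the first two checks, assuming the inputs and targets are plain tensors. `assert_finite` and `loss_is_blowing_up` are hypothetical helper names, and the factor-of-10 threshold is an arbitrary heuristic, not a PyTorch feature:

```python
import math

import torch


def assert_finite(name: str, t: torch.Tensor) -> None:
    """Raise if t contains NaN or +/-Inf (hypothetical helper)."""
    bad = ~torch.isfinite(t)
    if bad.any():
        raise ValueError(f"{name}: {int(bad.sum())} non-finite values")


def loss_is_blowing_up(history, factor=10.0):
    """Crude heuristic: the latest loss is non-finite or `factor` times
    larger than the first recorded loss."""
    last = history[-1]
    return not math.isfinite(last) or abs(last) > factor * abs(history[0])


x = torch.randn(1000, 1)       # placeholder node features
target = torch.randn(1000, 1)  # placeholder labels
assert_finite("x", x)
assert_finite("target", target)
print(loss_is_blowing_up([0.5, 0.4, 0.3]))   # False
print(loss_is_blowing_up([0.5, 3.0, 80.0]))  # True
```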


The inputs don’t have any invalid values, and the loss is also close to 0. Could you please elaborate on how exactly to check the intermediate activations to detect the issue? Thanks,

You can use forward hooks as described here to check all intermediate outputs for NaN values.
Since the inputs are valid and the loss doesn’t seem to explode, I guess a particular layer might create these invalid outputs, which are then propagated to the loss calculation.
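A sketch of the forward-hook approach, using only `named_modules` and `register_forward_hook` from `torch.nn.Module`. The `add_nan_hooks` helper and the toy `Normalize` module are hypothetical; the demo forces every ReLU output to zero so that the normalization computes 0/0:

```python
import torch
import torch.nn as nn


def add_nan_hooks(model: nn.Module):
    """Register a forward hook on every submodule that raises on the first
    non-finite (NaN/Inf) output, naming the module that produced it."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if torch.is_tensor(out) and not torch.isfinite(out).all():
                    raise RuntimeError(
                        f"non-finite output in module '{name}' ({type(mod).__name__})"
                    )
        handles.append(module.register_forward_hook(hook))
    return handles  # call h.remove() on each handle to detach the hooks


class Normalize(nn.Module):
    """Toy layer: y / ||y|| per row, which is 0/0 = NaN for an all-zero row."""

    def forward(self, y):
        return y / y.norm(dim=1, keepdim=True)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), Normalize())
with torch.no_grad():
    model[0].weight.fill_(-1.0)  # force every ReLU output to 0 for this demo
    model[0].bias.zero_()

handles = add_nan_hooks(model)
message = ""
try:
    model(torch.ones(2, 4))
except RuntimeError as err:
    message = str(err)
print(message)  # non-finite output in module '2' (Normalize)
for h in handles:
    h.remove()
```

Because hooks fire in execution order, the first module reported is the one that created the invalid values, not one that merely propagated them.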

I checked again. My loss seems stuck and isn’t changing with the number of epochs:

```
Epoch: 0001 loss_train: -0.0155 time: 0.0748s
Time Net =  0.07779312133789062
Epoch: 0051 loss_train: -0.0154 time: 0.0160s
Time Net =  1.1269891262054443
Epoch: 0101 loss_train: -0.0153 time: 0.0170s
Time Net =  1.9906792640686035
Epoch: 0151 loss_train: -0.0153 time: 0.0170s
Time Net =  2.8633458614349365
```

So I checked the gradient flow.

Apparently, increasing the number of linear layers still gives the same loss_train value, and the gradient flow remains the same.
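One way such a gradient-flow check might look, as a sketch: the small `nn.Sequential` model and random data are placeholders, and `grad_norms` is a hypothetical helper. A `None` entry means the parameter received no gradient at all (a broken graph), while uniformly tiny norms would point to vanishing gradients:

```python
import torch
import torch.nn as nn


def grad_norms(model: nn.Module) -> dict:
    """L2 norm of each parameter's .grad after backward(); None means the
    parameter received no gradient (e.g. it is cut off from the loss)."""
    return {
        name: (p.grad.norm().item() if p.grad is not None else None)
        for name, p in model.named_parameters()
    }


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

norms = grad_norms(model)
for name, norm in norms.items():
    print(f"{name:10s} {norm}")
```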

What could be the reason here?

I would focus on the NaN issue first before diving into the model and checking the gradient flow, as the former issue is more concerning.
Were you able to isolate the first NaN output?

I figured out that the loss becomes NaN because of an objective function I’m using for unsupervised learning. One of the steps in this objective function is to normalize the output of the model, i.e., Y / ||Y||.

Whenever Y is a zero tensor, this normalization produces a NaN output (0 / 0). So there seems to be an issue with the model, as its output is sometimes a zero tensor. Also, as mentioned in my previous post, the loss doesn’t change with an increasing number of epochs.
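A minimal sketch of the failure and a guarded alternative, assuming the normalization is taken per row (`dim=1`). `torch.nn.functional.normalize` divides by `max(||Y||, eps)`, so an all-zero `Y` maps to zeros instead of NaN. This only removes the NaN; it does not explain why the model outputs zeros, and the gradient at an all-zero input is very large (on the order of 1/eps):

```python
import torch
import torch.nn.functional as F

Y = torch.zeros(3, 4, requires_grad=True)  # an all-zero model output

naive = Y / Y.norm(dim=1, keepdim=True)         # 0 / 0 -> NaN in every entry
safe = F.normalize(Y, p=2.0, dim=1, eps=1e-12)  # divides by max(||Y||, eps) -> zeros

safe.sum().backward()
print(torch.isnan(naive).all().item())      # True
print(safe.abs().max().item())              # 0.0
print(torch.isfinite(Y.grad).all().item())  # True
```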

So what should I do to find out why the output is a zero tensor?


I’m not sure what to check first. Is your objective function forcing the model to output a zero tensor in some way, which might then create the NaN outputs?

The architecture is as follows: the model output is fed into the objective function, and the outputs of both the objective function and the model are fed into the loss function.

But the model output itself comes out as a zero tensor, which is then fed into the objective function. What should be done here? Any ideas why the model output is a zero-valued tensor?
Thanks.

Also, as mentioned in my previous post, the loss doesn’t seem to change. Is there a way to check why this might be happening?
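One hypothesis worth checking is the final `F.relu` in `forward`: it maps every non-positive value coming out of `lay2` to exactly 0, so if `lay2` outputs mostly negative values, the whole output becomes a zero tensor. The sketch below uses a hypothetical stand-in `Head` module (only the `lay1 -> ReLU -> lay2 -> ReLU` part, without the GCN layers, with made-up sizes) and forward hooks to record the fraction of non-positive pre-activations per layer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Head(nn.Module):
    """Hypothetical stand-in for the lay1 -> ReLU -> lay2 -> ReLU part of the model."""

    def __init__(self, inp_l=16, hid_l=32, out_l=1):
        super().__init__()
        self.lay1 = nn.Linear(inp_l, hid_l)
        self.lay2 = nn.Linear(hid_l, out_l)

    def forward(self, x):
        x = F.relu(self.lay1(x))
        return F.relu(self.lay2(x))  # any non-positive pre-activation becomes exactly 0


stats = {}


def record_nonpositive(name):
    def hook(module, inputs, output):
        # fraction of entries that the following ReLU will set to zero
        stats[name] = (output <= 0).float().mean().item()
    return hook


torch.manual_seed(0)
head = Head()
head.lay1.register_forward_hook(record_nonpositive("lay1"))
head.lay2.register_forward_hook(record_nonpositive("lay2"))
out = head(torch.randn(1000, 16))
print(stats)
print("all-zero output:", (out == 0).all().item())
```

Logging `stats["lay2"]` during training would show whether the output drifts toward all zeros; a value of 1.0 means the final ReLU is zeroing every output.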

The loss seems to decrease, but really slowly.
Checking the gradients, as you already did, is a valid way to see if you have accidentally broken the computation graph. Since that doesn’t seem to be the case, you would have to verify whether your approach (using the custom objective function etc.) works at all.
To do so, I would recommend trying to overfit a small dataset, e.g. just 10 samples, and making sure your current training routine and model are able to overfit it.
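A minimal overfitting sanity check might look like this. The 10 random samples, the small MLP, Adam, and the learning rate are placeholders to swap for the real model, objective, and data; if the loss cannot get close to 0 on 10 samples, the training routine or objective is the more likely culprit than the data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Placeholder "small dataset": 10 random samples standing in for real inputs/targets
x_small = torch.randn(10, 8)
y_small = torch.randn(10, 1)

model = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

with torch.no_grad():
    initial_loss = loss_fn(model(x_small), y_small).item()

for step in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x_small), y_small)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step:3d}  loss {loss.item():.6f}")

with torch.no_grad():
    final_loss = loss_fn(model(x_small), y_small).item()
print(f"initial {initial_loss:.4f} -> final {final_loss:.6f}")
```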


I’m having a similar problem, and I’ve discovered where in the process the NaN is being produced, but I’m not sure why or how I can go about fixing it. I’m currently just trying to overfit to a single training example as a sanity check for model selection.

I’m using the root-mean-square angular distance between spherical coordinates as the loss function (I’ve also experimented with MSE and RMSE, but it doesn’t make a difference to the loss becoming NaN). The code is below.

```python
import torch


def distance_between_spherical_coordinates_rad(az1, ele1, az2, ele2, deg=True):
    """
    Angular distance between two spherical coordinates
    MORE: https://en.wikipedia.org/wiki/Great-circle_distance

    :return: root-mean-square angular distance in degrees or radians
    """
    dist = torch.sin(ele1) * torch.sin(ele2) + torch.cos(ele1) * torch.cos(ele2) * torch.cos(torch.abs(az1 - az2))
    # Keep dist in [-1, 1]; torch.arccos returns NaN outside this range
    dist = torch.clip(dist, -1, 1)
    # Note: deg defaults to True, so callers that omit it get degrees
    if deg:
        dist = torch.arccos(dist) * 180 / torch.pi
    else:
        dist = torch.arccos(dist)
    return torch.sqrt(torch.mean(dist ** 2))
```

There appear to be no NaN values in the input or target; I’ve checked this with the code below.

```python
with torch.autograd.detect_anomaly():
    for i in range(runs):
        optimizer.zero_grad()

        preds = model(batch_input)

        preds = utils.cart2sph(preds, deg=False)
        batch_target_sph = utils.cart2sph(batch_target, deg=False)

        if torch.isnan(preds[0]).any():
            pass  # breakpoint is here
        if torch.isnan(preds[1]).any():
            pass  # breakpoint is here

        if torch.isnan(batch_target_sph[0]).any():
            pass  # breakpoint is here
        if torch.isnan(batch_target_sph[1]).any():
            pass  # breakpoint is here

        # deg is not passed here, so the loss uses its default deg=True
        loss = utils.distance_between_spherical_coordinates_rad(preds[0], preds[1], batch_target_sph[0], batch_target_sph[1])

        # loss = loss_fn(preds, batch_target)
        loss.backward()
        optimizer.step()
```

When running the training loop, I get the following error message from the anomaly detector:

```
[W ..\torch\csrc\autograd\python_anomaly_mode.cpp:104] Warning: Error detected in AcosBackward0. Traceback of forward call that caused the error:
  File "D:/Dan_PC_Stuff/Pycharm_projects/sceneGeneration/train.py", line 57, in <module>
    loss = utils.distance_between_spherical_coordinates_rad(preds[0],preds[1], batch_target_sph[0], batch_target_sph[1])
  File "D:\Dan_PC_Stuff\Pycharm_projects\sceneGeneration\utils.py", line 261, in distance_between_spherical_coordinates_rad
    dist = torch.arccos(dist) * 180 / torch.pi
 (function _print_stack)
Traceback (most recent call last):
  File "D:/Dan_PC_Stuff/Pycharm_projects/sceneGeneration/train.py", line 60, in <module>
    loss.backward()
  File "C:\Users\Audio\anaconda3\lib\site-packages\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\Audio\anaconda3\lib\site-packages\torch\autograd\__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'AcosBackward0' returned nan values in its 0th output.
```

It seems to happen in the backward pass of the loss function, specifically through dist = torch.arccos(dist) * 180 / torch.pi, which is odd, as that line of code shouldn’t be triggered since I’ve got deg=False within the training loop.

I think this code is causing the issue:

```python
# Making sure the dist values are in -1 to 1 range, else np.arccos kills the job
dist = torch.clip(dist, -1, 1)
if deg == True:
    dist = torch.arccos(dist) * 180 / torch.pi
else:
    dist = torch.arccos(dist)
```

While arccos will accept inputs in [-1, 1], the grads would be -Inf at these edge points:

```python
import torch

dist = torch.tensor([1.], requires_grad=True)
out = torch.arccos(dist)
out.backward()
print(dist.grad)
# tensor([-inf])
```
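This also explains why anomaly mode reports NaN rather than -inf: the loss squares the arccos result (`dist ** 2`), and the gradient of the square is 0 exactly where arccos(1) = 0, so the backward pass computes 0 * (-inf), which is NaN in IEEE floating point. A minimal reproduction:

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
angle = torch.arccos(x)  # arccos(1) = 0; local gradient -1/sqrt(1 - x**2) = -inf here
loss = angle ** 2        # d(loss)/d(angle) = 2 * angle = 0 at this point
loss.backward()
print(x.grad)            # tensor([nan]), because 0 * (-inf) is NaN
```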

What can I do to try and work around this?

I found this solution here, which suggests using a small epsilon to bring the range just inside [-1, 1].

But trying a range of epsilon values seems to result in a gradient of 0, as shown below, even with a value of 1e-4, which is much larger than the one proposed in the link above.

For reference, I’m using dtype = torch.float32.

```python
import torch

epsilon = 1e-4
dist = torch.tensor([1.], requires_grad=True)
out = torch.arccos(torch.clamp(dist, -1 + epsilon, 1 - epsilon))
out.backward()
print(dist.grad)
# tensor([0.])
```


However, if I use 0.9999 directly as the value of dist, the gradient is non-zero:

```python
import torch

dist = torch.tensor([0.9999], requires_grad=True)
out = torch.arccos(dist)
out.backward()
print(dist.grad)
# tensor([-70.7048])
```

Edit:

After some advice from a colleague and some further searching, I found this, which explains that torch.clamp() isn’t differentiable at min or max and therefore sets the gradients to 0. Is there any way I can get around this, or shouldn’t this be a major problem?

I don’t think this is the case here.
Testing dist at exactly the edge points still yields a valid gradient of 1 through clamp:

```python
import torch

dist = torch.tensor([1.0], requires_grad=True)
torch.clamp(dist, -1, 1).backward()
print(dist.grad)
# tensor([1.])
```

The zero gradient you are seeing in out = torch.arccos(torch.clamp(dist, -1+epsilon, 1-epsilon)) is expected, since the gradient outside of the clamped range is zero, as seen here:

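```python
import matplotlib.pyplot as plt
import torch

x = torch.linspace(-1.5, 1.5, 100)
eps = 0.1
y = torch.clamp(x, -1 + eps, 1 - eps)
plt.plot(x, y)  # flat (zero slope, hence zero gradient) outside [-0.9, 0.9]
plt.show()
```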
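A common workaround, not discussed in this thread, is to avoid arccos altogether and compute the angle between two 3D vectors as atan2(|u × v|, u · v), whose gradient stays finite for parallel and antiparallel (non-zero) vectors. The sketch assumes the predictions and targets are available as Cartesian vectors of shape [N, 3] before `utils.cart2sph` is applied; `angular_distance` is a hypothetical name, and `torch.linalg.cross` requires PyTorch 1.11 or newer:

```python
import math

import torch


def angular_distance(u: torch.Tensor, v: torch.Tensor, deg: bool = False) -> torch.Tensor:
    """Angle between corresponding rows of u and v (shape [N, 3]), computed as
    atan2(|u x v|, u . v). No clamping needed and rows need not be unit length,
    but they must be non-zero."""
    sin_part = torch.linalg.vector_norm(torch.linalg.cross(u, v, dim=-1), dim=-1)
    cos_part = (u * v).sum(dim=-1)
    angle = torch.atan2(sin_part, cos_part)
    return torch.rad2deg(angle) if deg else angle


# Identical, orthogonal, and opposite direction pairs: the arccos edge cases
u = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], requires_grad=True)
v = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])

angles = angular_distance(u, v)
angles.mean().backward()
print(angles.detach(), "vs", [0.0, math.pi / 2, math.pi])
print(torch.isfinite(u.grad).all().item())  # True
print(angular_distance(u, v, deg=True).detach())  # approximately [0, 90, 180]
```

The example takes the plain mean of the angles. Wrapping them in sqrt(mean(angle ** 2)), as the original function does, can still give NaN when every angle in the batch is exactly 0, because torch.sqrt also has an infinite gradient at 0.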


Mate, you have the patience of a saint.


Hi, I found that my loss becomes NaN after almost 14 epochs. Enabling anomaly detection helped me find that LogSoftmax was getting zero inputs, which made the losses suddenly become NaN. Back-tracing the inputs to see how they suddenly became zero, I found that in the middle of the 14th epoch the inputs become zero tensors after the transformation, and min-max rescaling of these zero tensors was causing NaN inputs to the model.
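A guarded min-max rescaling, as a sketch (`min_max_rescale` and the `eps` value are assumptions, not a torchvision API): clamping the denominator avoids 0/0 when a sample is constant, such as an all-zero tensor. Whether a constant sample should become zeros or be skipped altogether is a modeling choice:

```python
import torch


def min_max_rescale(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Rescale x to [0, 1]. The eps floor on the range avoids 0/0 when x is
    constant (e.g. an all-zero tensor), returning zeros instead of NaN."""
    x_min, x_max = x.min(), x.max()
    return (x - x_min) / (x_max - x_min).clamp_min(eps)


zeros = torch.zeros(3, 32, 32)
naive = (zeros - zeros.min()) / (zeros.max() - zeros.min())  # 0 / 0 -> NaN
print(torch.isnan(naive).all().item())                 # True
print(min_max_rescale(zeros).abs().max().item())       # 0.0
print(min_max_rescale(torch.tensor([2.0, 4.0, 6.0])))  # values 0.0, 0.5, 1.0
```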

However, during the previous 13 epochs, the dataloader must have passed the whole dataset through these transforms without hitting this zero-tensor problem. I do not understand why the problem happens after 13 full passes over the dataset have completed successfully. Why doesn’t the zero-tensor problem come up during the first epoch itself?

Could you help me get to the cause of the issue?

Could you check whether these transformations are applied in-place on the inputs? You could check it by printing the stats of each batch in each epoch to track the min, mean, and max values.
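To see how an in-place transform could make the data drift across epochs, here is a sketch with a hypothetical dataset that stores all samples in one tensor and returns views of it. `scale_inplace` overwrites the stored data on every access, so the per-epoch min/mean/max keep shrinking, while `scale_copy` leaves them unchanged. All names and sizes are illustrative:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TensorStoreDataset(Dataset):
    """Hypothetical dataset that keeps every sample in a single stored tensor."""

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # self.data[idx] is a view, so an in-place transform overwrites the stored sample
        return self.transform(self.data[idx])


def scale_inplace(x):
    return x.mul_(0.5)  # modifies the stored data on every access


def scale_copy(x):
    return x * 0.5  # returns a new tensor; stored data untouched


def epoch_stats(loader):
    """(min, mean, max) over all batches of one epoch."""
    values = torch.cat(list(loader))
    return values.min().item(), values.mean().item(), values.max().item()


def run(transform, epochs=3):
    torch.manual_seed(0)
    ds = TensorStoreDataset(torch.rand(100, 4) + 1.0, transform)
    loader = DataLoader(ds, batch_size=10, shuffle=True)
    return [epoch_stats(loader) for _ in range(epochs)]


inplace_stats = run(scale_inplace)
copy_stats = run(scale_copy)
for epoch, (a, b) in enumerate(zip(inplace_stats, copy_stats)):
    print(f"epoch {epoch}: in-place {a}  copy {b}")
```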

No, the transformations are not in-place. The min and max seem normal for a while and then suddenly both become 0.0 at a later epoch.

At each epoch, the system sees the entire dataset. If the transforms cause this issue on a particular data point, shouldn’t that happen at some point during each epoch when that datapoint is encountered? How does it happen only after few epochs have passed?

Yes, I would assume the issue should be visible in the first epoch if it’s really caused by the transformation and the transformation is deterministic. Since you’ve confirmed you are not transforming the samples in-place, I don’t have another idea besides checking whether the transformations are applied randomly somehow.
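As a sketch of how a random transform could produce a constant sample only occasionally, the hypothetical dataset below takes random 8x8 crops from 32x32 images whose bottom-right 8x8 corner is zero, so only 1 of the 625 possible crop positions is all zeros. Wrapping it in a check that raises on constant samples reports the epoch and index at which the problem first appears; all class names and sizes are illustrative:

```python
import torch
from torch.utils.data import Dataset


class RandomCropDataset(Dataset):
    """Hypothetical dataset: random 8x8 crops from 32x32 images whose
    bottom-right 8x8 corner is all zeros."""

    def __init__(self, n=10):
        imgs = torch.rand(n, 32, 32) + 0.1  # strictly positive content
        imgs[:, 24:, 24:] = 0.0
        self.imgs = imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        top, left = torch.randint(0, 25, (2,)).tolist()
        # only the crop at (24, 24) is all zeros: 1 in 625 draws
        return self.imgs[idx, top:top + 8, left:left + 8]


class CheckedDataset(Dataset):
    """Wrapper that fails loudly on a constant sample, reporting its index."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x = self.base[idx]
        if x.max() == x.min():  # min-max rescaling would divide by zero here
            raise ValueError(f"constant sample at index {idx}")
        return x


torch.manual_seed(0)
ds = CheckedDataset(RandomCropDataset())
first_bad_epoch = None
for epoch in range(1000):
    try:
        for i in range(len(ds)):
            ds[i]
    except ValueError as err:
        first_bad_epoch = epoch
        print(f"epoch {epoch}: {err}")
        break
```

Fixing the seed as above also makes such a failure reproducible, so the offending sample and random draw can be inspected directly.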