# Loss.backward() returns nan

Hi everyone,
I’m trying to build an LSTM-based model for predicting sequences. I’ve listed the loss function and the model below:

```python
import torch
import torch.nn as nn

def MetricOrthLoss(position_a, position_b, epsilon=1e-8):
    # Transform into directional vectors in the Cartesian coordinate system
    norm_a = torch.sqrt(torch.square(position_a[:, :, 0:1]) + torch.square(position_a[:, :, 1:2]) + torch.square(position_a[:, :, 2:3])) + epsilon
    norm_b = torch.sqrt(torch.square(position_b[:, :, 0:1]) + torch.square(position_b[:, :, 1:2]) + torch.square(position_b[:, :, 2:3])) + epsilon
    x_true = position_a[:, :, 0:1] / norm_a
    y_true = position_a[:, :, 1:2] / norm_a
    z_true = position_a[:, :, 2:3] / norm_a
    x_pred = position_b[:, :, 0:1] / norm_b
    y_pred = position_b[:, :, 1:2] / norm_b
    z_pred = position_b[:, :, 2:3] / nor_b
    # Finally compute the orthodromic (great-circle) distance
    # great_circle_distance = np.arccos(x_true*x_pred + y_true*y_pred + z_true*z_pred)
    # Clamp to keep the values in bounds between -1 and 1
    great_circle_distance = torch.acos(torch.clamp(x_true * x_pred + y_true * y_pred + z_true * z_pred, -1.0, 1.0))
    return great_circle_distance.mean()

# This way we ensure that the network learns to predict the delta angle
def toPosition(values):
    orientation = values[0]
    motion = values[1]
    return orientation + motion

class TRACK_MODEL(nn.Module):
    def __init__(self, M_WINDOW, H_WINDOW, NUM_TILES_HEIGHT, NUM_TILES_WIDTH,
                 input_size_pos=3, input_size_saliency=1, hidden_size=256):
        super().__init__()
        # Define encoders
        self.pos_enc = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_enc = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH, hidden_size=hidden_size, batch_first=True)
        self.fuse_1_enc = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)

        # Define decoders
        self.pos_dec = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_dec = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH, hidden_size=hidden_size, batch_first=True)
        self.fuse_1_dec = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)

        self.fc_1 = nn.Linear(hidden_size, hidden_size)
        self.fc_layer_out = nn.Linear(hidden_size, 3)
        self.output_horizon = H_WINDOW
        self.input_window = M_WINDOW

    def forward(self, encoder_pos_inputs, encoder_sal_inputs, decoder_pos_inputs, decoder_sal_inputs, verbose=False):
        batch_size = encoder_pos_inputs.size(0)

        # Encode position inputs
        out_enc_pos, (h_n_pos, c_n_pos) = self.pos_enc(encoder_pos_inputs)
        check_for_nans_or_infs(out_enc_pos, 'out_enc_pos')

        # Flatten saliency input
        flat_enc_sal_inputs = encoder_sal_inputs.view(batch_size, self.input_window, -1)
        out_enc_sal, (h_n_sal, c_n_sal) = self.sal_enc(flat_enc_sal_inputs)
        check_for_nans_or_infs(out_enc_sal, 'out_enc_sal')

        # Concatenate encoder outputs
        conc_out_enc = torch.cat([out_enc_sal, out_enc_pos], dim=-1)
        fuse_out_enc, (h_n_fuse, c_n_fuse) = self.fuse_1_enc(conc_out_enc)
        check_for_nans_or_infs(conc_out_enc, 'conc_out_enc')
        check_for_nans_or_infs(fuse_out_enc, 'fuse_out_enc')

        dec_input = decoder_pos_inputs
        all_pos_outputs = []
        for t in range(self.output_horizon):
            # Decode position at the current timestep
            dec_pos_out, (h_n_pos, c_n_pos) = self.pos_dec(dec_input, (h_n_pos, c_n_pos))
            check_for_nans_or_infs(dec_pos_out, 'dec_pos_out')

            # Decode saliency at the current timestep
            selected_timestep_saliency = decoder_sal_inputs[:, t:t + 1].view(batch_size, 1, -1)
            dec_sal_out, (h_n_sal, c_n_sal) = self.sal_dec(selected_timestep_saliency, (h_n_sal, c_n_sal))
            check_for_nans_or_infs(dec_sal_out, 'dec_sal_out')

            # Decode concatenated values
            dec_out = torch.cat((dec_sal_out, dec_pos_out), dim=-1)
            fuse_out_dec_1, (h_n_fuse, c_n_fuse) = self.fuse_1_dec(dec_out, (h_n_fuse, c_n_fuse))
            check_for_nans_or_infs(dec_out, 'dec_out')
            check_for_nans_or_infs(fuse_out_dec_1, 'fuse_out_dec_1')

            # FC layers
            dec_fuse_out = self.fc_1(fuse_out_dec_1)
            outputs_delta = self.fc_layer_out(dec_fuse_out)
            # Apply toPosition to turn the predicted delta into an absolute position
            decoder_pred = toPosition([dec_input, outputs_delta])
            check_for_nans_or_infs(dec_fuse_out, 'dec_fuse_out')
            check_for_nans_or_infs(outputs_delta, 'outputs_delta')
            check_for_nans_or_infs(decoder_pred, 'decoder_pred')
            all_pos_outputs.append(decoder_pred)
            dec_input = decoder_pred
        decoder_outputs_pos = torch.cat(all_pos_outputs, dim=1)
        return decoder_outputs_pos
```

The problem is that at some point loss.backward() returns NaN values. I’ve checked all the inputs and confirmed that the pos_inputs are all 3D unit vectors, while the sal_inputs are H×W tensors with values between -1 and 1. The model returns a normal loss value (not NaN) for the very batch where the backward step returns NaN. I’ve tried clipping the gradients, lowering the learning rate, and using a different initialization approach, but it still eventually produces NaN values within the first or second epoch. Setting detect_anomaly to True points at:
```
decoder_outputs_pos = torch.cat(all_pos_outputs, dim=1)
```

in backward:

```
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'CatBackward0' returned nan values in its 10th output.
```

For a specific run, the model reached batch 143 before the values became NaN. The losses were going down as expected, and the loss for that batch was 0.4881. But loss.backward() returns NaN at CatBackward0 and all the gradients become NaN. The learning rate is 0.0005, but I’ve also tried 0.0001 and 0.00005 with similar results.
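Since the forward-pass checks all pass, one way to localize where the first non-finite gradient appears is to register backward hooks on the same intermediate tensors that are already passed to `check_for_nans_or_infs`. A minimal sketch of the idea (the `acos` expression is a toy stand-in for the model, not code from the thread):

```python
import torch

def grad_check(name):
    """Backward hook: report a non-finite gradient reaching this tensor."""
    def hook(grad):
        if not torch.isfinite(grad).all():
            print(f"non-finite gradient at {name}")
    return hook

# acos has an infinite derivative at +/-1, so the loss is finite
# while the gradient is not -- the same symptom as in the question
x = torch.tensor(1.0, requires_grad=True)
x.register_hook(grad_check('x'))
loss = torch.acos(x)   # forward value: 0.0, perfectly finite
loss.backward()        # d/dx acos(x) = -1/sqrt(1 - x^2) -> -inf at x = 1
print(torch.isfinite(x.grad).all().item())  # False
```

Hooking every tensor that feeds a `torch.cat` narrows down which branch the bad gradient arrives from.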

Run your code within a `torch.autograd.detect_anomaly(check_nan=True)` context manager to detect which `torch.cat` operation is causing the NaN; docs here: Automatic differentiation package - torch.autograd — PyTorch 2.4 documentation
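A self-contained illustration of the suggestion (the `0.0 * torch.acos(x)` expression is a made-up stand-in for a model, not code from the thread): the forward value is finite, but anomaly mode raises as soon as a backward function produces a NaN, naming that function in the error message.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
caught = False
with torch.autograd.detect_anomaly(check_nan=True):
    # Forward value is perfectly finite: 0.0 * acos(1.0) == 0.0
    y = 0.0 * torch.acos(x)
    try:
        # Backward computes 0 * (-1/sqrt(1 - 1)) = 0 * -inf = nan,
        # so anomaly mode raises and points at the offending backward function
        y.backward()
    except RuntimeError:
        caught = True
print(caught)  # True
```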

I ran it within the context manager and got this output

``````....
File "C:\Users\Varun\AppData\Local\Temp\ipykernel_23572\4244610133.py", line 3, in <module>
prediction=model(b143['encoder_pos_inputs'],b143['encoder_sal_inputs'],b143['decoder_pos_inputs'],b143['decoder_sal_inputs'])
File "c:\Users\Varun\Desktop\Projects....\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "c:\Users\Varun\Desktop\Projects\....\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "c:\Users\Varun\Desktop\Projects\....\TRACK_SAL.py", line 130, in forward
decoder_outputs_pos=torch.cat(all_pos_outputs,dim=1)
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
``````

I’m running it on a checkpoint of the model trained up to the batch that resulted in NaN values, and b143 holds the inputs of the batch that caused the crash for this particular run.
I’m not entirely sure why this torch.cat operation is causing issues, since the final output seems fine.

It’s not the `torch.cat` operation itself that’s causing the issue; it’s an operation after the `torch.cat` call, since the operations of the graph are traversed in reverse when computing gradients.
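A tiny example of that direction reversal: `CatBackward0` only slices the incoming gradient back to its inputs, so a NaN it "returns" must already be arriving from an operation that came *after* the `cat` in the forward pass (here, a hypothetical elementwise multiply):

```python
import torch

a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
cat = torch.cat([a, b])

# An op downstream of cat injects the NaN into the backward pass
weights = torch.tensor([float('nan'), 1.0])
(cat * weights).sum().backward()

print(a.grad)  # nan: produced by the multiply, merely routed through CatBackward0
print(b.grad)  # finite
```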

The only operation performed after that specific `torch.cat` is the loss calculation, since it’s the last operation of the forward step, and that returns a finite loss for the batch. I’ve also added an epsilon to all denominators just in case, to avoid division by zero, and I’ve clamped the values going into `torch.acos`.

I’d check whether this function has a usable gradient, as `torch.clamp` sets any value outside the range to the respective edge, but it won’t have a gradient there.

Other potential issues are `torch.acos` and `torch.sqrt` having ill-defined inputs, and `norm_a` and `norm_b` being equal to 0.
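This matches the failure mode in the question: when the dot product of two (nearly) identical unit vectors rounds to exactly 1.0, the clamp leaves it at the boundary, `acos` returns a finite 0.0, but its derivative -1/sqrt(1 - x^2) is -inf, which then turns into NaNs downstream. Clamping to slightly inside the interval (the `eps` margin below is illustrative, not from the thread) keeps the gradient finite:

```python
import torch

# A dot product that rounded to exactly 1.0
x = torch.tensor(1.0, requires_grad=True)

# Original: clamp to [-1, 1] -- the forward value is fine, the backward is not
loss = torch.acos(torch.clamp(x, -1.0, 1.0))
loss.backward()
print(loss.item(), torch.isfinite(x.grad).all().item())   # 0.0 False

# With a small margin the boundary value is pushed off the singularity,
# and out-of-range inputs get zero gradient from the clamp
y = torch.tensor(1.0, requires_grad=True)
eps = 1e-6
loss2 = torch.acos(torch.clamp(y, -1.0 + eps, 1.0 - eps))
loss2.backward()
print(torch.isfinite(y.grad).all().item())                # True
```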

Hi, sorry for the delay. I found the issue. You were right, the loss function’s clamping was causing the problem, so I looked into the literature. The code I had based this on misquoted the original paper: it used the orthodromic distance as the training loss, while the paper only used it as a separate test metric and trained with MSELoss. Thanks for the help.
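For anyone landing here later, that resolution can be sketched as follows (names and shapes are illustrative, not from the thread): train with `nn.MSELoss` and report the orthodromic distance under `torch.no_grad()`, so its singular gradient never enters the backward pass.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()

pred = torch.randn(8, 25, 3, requires_grad=True)   # (batch, horizon, xyz)
target = torch.nn.functional.normalize(torch.randn(8, 25, 3), dim=-1)

loss = criterion(pred, target)    # training loss: gradients are always finite
loss.backward()

with torch.no_grad():             # orthodromic distance as a reported metric only
    dots = (torch.nn.functional.normalize(pred, dim=-1) * target).sum(dim=-1)
    metric = torch.acos(torch.clamp(dots, -1.0, 1.0)).mean()

print(torch.isfinite(pred.grad).all().item())   # True
```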