# Loss.backward() returns nan

Hi everyone,
I’m trying to build an LSTM-based model for predicting sequences. I’ve listed the loss function and the model below:

```python
import torch
import torch.nn as nn

def MetricOrthLoss(position_a, position_b, epsilon=1e-8):
    # Transform into directional vectors in the Cartesian coordinate system
    norm_a = torch.sqrt(torch.square(position_a[:, :, 0:1]) + torch.square(position_a[:, :, 1:2]) + torch.square(position_a[:, :, 2:3])) + epsilon
    norm_b = torch.sqrt(torch.square(position_b[:, :, 0:1]) + torch.square(position_b[:, :, 1:2]) + torch.square(position_b[:, :, 2:3])) + epsilon
    x_true = position_a[:, :, 0:1] / norm_a
    y_true = position_a[:, :, 1:2] / norm_a
    z_true = position_a[:, :, 2:3] / norm_a
    x_pred = position_b[:, :, 0:1] / norm_b
    y_pred = position_b[:, :, 1:2] / norm_b
    z_pred = position_b[:, :, 2:3] / nor_b
    # Finally compute the orthodromic (great-circle) distance
    # great_circle_distance = np.arccos(x_true*x_pred + y_true*y_pred + z_true*z_pred)
    # Clamp to keep the values in bounds between -1 and 1
    great_circle_distance = torch.acos(torch.clamp(x_true * x_pred + y_true * y_pred + z_true * z_pred, -1.0, 1.0))
    return great_circle_distance.mean()

# This way we ensure that the network learns to predict the delta angle
def toPosition(values):
    orientation = values[0]
    motion = values[1]
    return orientation + motion

class TRACK_MODEL(nn.Module):
    def __init__(self, M_WINDOW, H_WINDOW, NUM_TILES_HEIGHT, NUM_TILES_WIDTH,
                 input_size_pos=3, input_size_saliency=1, hidden_size=256):
        super().__init__()
        # Define encoders
        self.pos_enc = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_enc = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH, hidden_size=hidden_size, batch_first=True)
        self.fuse_1_enc = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)

        # Define decoders
        self.pos_dec = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_dec = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH, hidden_size=hidden_size, batch_first=True)
        self.fuse_1_dec = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)

        self.fc_1 = nn.Linear(hidden_size, hidden_size)
        self.fc_layer_out = nn.Linear(hidden_size, 3)
        self.output_horizon = H_WINDOW
        self.input_window = M_WINDOW

    def forward(self, encoder_pos_inputs, encoder_sal_inputs, decoder_pos_inputs, decoder_sal_inputs, verbose=False):
        batch_size = encoder_pos_inputs.size(0)

        # Encode position inputs
        out_enc_pos, (h_n_pos, c_n_pos) = self.pos_enc(encoder_pos_inputs)
        check_for_nans_or_infs(out_enc_pos, 'out_enc_pos')

        # Flatten saliency input
        flat_enc_sal_inputs = encoder_sal_inputs.view(batch_size, self.input_window, -1)
        out_enc_sal, (h_n_sal, c_n_sal) = self.sal_enc(flat_enc_sal_inputs)
        check_for_nans_or_infs(out_enc_sal, 'out_enc_sal')

        # Concatenate encoder outputs
        conc_out_enc = torch.cat([out_enc_sal, out_enc_pos], dim=-1)
        fuse_out_enc, (h_n_fuse, c_n_fuse) = self.fuse_1_enc(conc_out_enc)
        check_for_nans_or_infs(conc_out_enc, 'conc_out_enc')
        check_for_nans_or_infs(fuse_out_enc, 'fuse_out_enc')

        dec_input = decoder_pos_inputs
        all_pos_outputs = []
        for t in range(self.output_horizon):
            # Decode position at the current timestep
            dec_pos_out, (h_n_pos, c_n_pos) = self.pos_dec(dec_input, (h_n_pos, c_n_pos))
            check_for_nans_or_infs(dec_pos_out, 'dec_pos_out')

            # Decode saliency at the current timestep
            selected_timestep_saliency = decoder_sal_inputs[:, t:t + 1].view(batch_size, 1, -1)
            dec_sal_out, (h_n_sal, c_n_sal) = self.sal_dec(selected_timestep_saliency, (h_n_sal, c_n_sal))
            check_for_nans_or_infs(dec_sal_out, 'dec_sal_out')

            # Decode concatenated values
            dec_out = torch.cat((dec_sal_out, dec_pos_out), dim=-1)
            fuse_out_dec_1, (h_n_fuse, c_n_fuse) = self.fuse_1_dec(dec_out, (h_n_fuse, c_n_fuse))
            check_for_nans_or_infs(dec_out, 'dec_out')
            check_for_nans_or_infs(fuse_out_dec_1, 'fuse_out_dec_1')

            # FC layers
            dec_fuse_out = self.fc_1(fuse_out_dec_1)
            outputs_delta = self.fc_layer_out(dec_fuse_out)
            # Apply toPosition to turn the predicted delta into an absolute position
            decoder_pred = toPosition([dec_input, outputs_delta])
            check_for_nans_or_infs(dec_fuse_out, 'dec_fuse_out')
            check_for_nans_or_infs(outputs_delta, 'outputs_delta')
            check_for_nans_or_infs(decoder_pred, 'decoder_pred')
            all_pos_outputs.append(decoder_pred)
            dec_input = decoder_pred
        decoder_outputs_pos = torch.cat(all_pos_outputs, dim=1)
        return decoder_outputs_pos
```

The problem is that at some point loss.backward() returns NaN values. I’ve checked all the inputs and confirmed that the pos_inputs are all 3D unit vectors, while the sal_inputs are H×W tensors with values between -1 and 1. The model returns a normal loss value (not NaN) for the very batch where the backward step returns NaN. I’ve tried clipping the gradients, lowering the learning rate, and using a different initialization approach, but it still eventually produces NaN values within the first or second epoch. Setting detect_anomaly to True points at:
```
decoder_outputs_pos = torch.cat(all_pos_outputs, dim=1)
```

in backward:

```
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'CatBackward0' returned nan values in its 10th output.
```

For a specific run, the model reached batch 143 before the values became NaN. The losses were going down as expected, and the loss for that batch was 0.4881. But loss.backward() returns NaN at CatBackward0 and all the gradients become NaN. The learning rate is 0.0005, but I’ve also tried 0.0001 and 0.00005 with similar results.
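Since the forward-pass checks all pass, one way to localize where the first non-finite gradient appears is to register backward hooks on the same intermediate tensors that are already passed to `check_for_nans_or_infs`. A minimal sketch of the idea (the `acos` expression is a toy stand-in for the model, not code from the thread):

```python
import torch

def grad_check(name):
    """Backward hook: report a non-finite gradient reaching this tensor."""
    def hook(grad):
        if not torch.isfinite(grad).all():
            print(f"non-finite gradient at {name}")
    return hook

# acos has an infinite derivative at +/-1, so the loss is finite
# while the gradient is not -- the same symptom as in the question
x = torch.tensor(1.0, requires_grad=True)
x.register_hook(grad_check('x'))
loss = torch.acos(x)   # forward value: 0.0, perfectly finite
loss.backward()        # d/dx acos(x) = -1/sqrt(1 - x^2) -> -inf at x = 1
print(torch.isfinite(x.grad).all().item())  # False
```

Hooking every tensor that feeds a `torch.cat` narrows down which branch the bad gradient arrives from.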

Run your code within a `torch.autograd.detect_anomaly(check_nan=True)` context manager to detect which `torch.cat` operation is causing the NaN; docs here: Automatic differentiation package - torch.autograd — PyTorch 2.4 documentation
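A self-contained illustration of the suggestion (the `0.0 * torch.acos(x)` expression is a made-up stand-in for a model, not code from the thread): the forward value is finite, but anomaly mode raises as soon as a backward function produces a NaN, naming that function in the error message.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
caught = False
with torch.autograd.detect_anomaly(check_nan=True):
    # Forward value is perfectly finite: 0.0 * acos(1.0) == 0.0
    y = 0.0 * torch.acos(x)
    try:
        # Backward computes 0 * (-1/sqrt(1 - 1)) = 0 * -inf = nan,
        # so anomaly mode raises and points at the offending backward function
        y.backward()
    except RuntimeError:
        caught = True
print(caught)  # True
```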

I ran it within the context manager and got this output

``````....
File "C:\Users\Varun\AppData\Local\Temp\ipykernel_23572\4244610133.py", line 3, in <module>
prediction=model(b143['encoder_pos_inputs'],b143['encoder_sal_inputs'],b143['decoder_pos_inputs'],b143['decoder_sal_inputs'])
File "c:\Users\Varun\Desktop\Projects....\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "c:\Users\Varun\Desktop\Projects\....\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "c:\Users\Varun\Desktop\Projects\....\TRACK_SAL.py", line 130, in forward
decoder_outputs_pos=torch.cat(all_pos_outputs,dim=1)
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
``````

I’m running it on a checkpoint of the model trained up to the batch that resulted in NaN values, and b143 holds the inputs of the batch that caused the crash for this particular run.
I’m not entirely sure why this torch.cat operation is causing issues, since the final output seems fine.

It’s not the `torch.cat` operation itself that’s causing the issue; it’s an operation after the `torch.cat` call, since the operations of the graph are traversed in reverse when computing gradients.
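A tiny example of that direction reversal: `CatBackward0` only slices the incoming gradient back to its inputs, so a NaN it "returns" must already be arriving from an operation that came *after* the `cat` in the forward pass (here, a hypothetical elementwise multiply):

```python
import torch

a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
cat = torch.cat([a, b])

# An op downstream of cat injects the NaN into the backward pass
weights = torch.tensor([float('nan'), 1.0])
(cat * weights).sum().backward()

print(a.grad)  # nan: produced by the multiply, merely routed through CatBackward0
print(b.grad)  # finite
```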

The only operation performed after that specific `torch.cat` is the loss calculation, since it’s the last operation of the forward step, and that returns a finite loss for the batch. I’ve also added an epsilon to all denominators just in case, to avoid division by zero, and I’ve clamped the values going into `torch.acos`.

I’d check whether this function has a usable gradient, as `torch.clamp` sets any value outside the range to the respective edge, but it won’t have a gradient there.

Other potential issues are `torch.acos` and `torch.sqrt` having ill-defined inputs, and `norm_a` and `norm_b` being equal to 0.
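This matches the failure mode in the question: when the dot product of two (nearly) identical unit vectors rounds to exactly 1.0, the clamp leaves it at the boundary, `acos` returns a finite 0.0, but its derivative -1/sqrt(1 - x^2) is -inf, which then turns into NaNs downstream. Clamping to slightly inside the interval (the `eps` margin below is illustrative, not from the thread) keeps the gradient finite:

```python
import torch

# A dot product that rounded to exactly 1.0
x = torch.tensor(1.0, requires_grad=True)

# Original: clamp to [-1, 1] -- the forward value is fine, the backward is not
loss = torch.acos(torch.clamp(x, -1.0, 1.0))
loss.backward()
print(loss.item(), torch.isfinite(x.grad).all().item())   # 0.0 False

# With a small margin the boundary value is pushed off the singularity,
# and out-of-range inputs get zero gradient from the clamp
y = torch.tensor(1.0, requires_grad=True)
eps = 1e-6
loss2 = torch.acos(torch.clamp(y, -1.0 + eps, 1.0 - eps))
loss2.backward()
print(torch.isfinite(y.grad).all().item())                # True
```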

Hi, sorry for the delay. I found the issue. You were right, the loss function’s clamping was causing the problem, so I looked into the literature. The code I had based this on misquoted the original paper: it used the orthodromic distance as the training loss, while the paper only used it as a separate test metric and trained with MSELoss. Thanks for the help.
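For anyone landing here later, that resolution can be sketched as follows (names and shapes are illustrative, not from the thread): train with `nn.MSELoss` and report the orthodromic distance under `torch.no_grad()`, so its singular gradient never enters the backward pass.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()

pred = torch.randn(8, 25, 3, requires_grad=True)   # (batch, horizon, xyz)
target = torch.nn.functional.normalize(torch.randn(8, 25, 3), dim=-1)

loss = criterion(pred, target)    # training loss: gradients are always finite
loss.backward()

with torch.no_grad():             # orthodromic distance as a reported metric only
    dots = (torch.nn.functional.normalize(pred, dim=-1) * target).sum(dim=-1)
    metric = torch.acos(torch.clamp(dots, -1.0, 1.0)).mean()

print(torch.isfinite(pred.grad).all().item())   # True
```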