Hi everyone,
I’m trying to build an LSTM-based model for predicting sequences. The loss function and the model are listed below:
def MetricOrthLoss(position_a, position_b,epsilon=1e-8):
# Transform into directional vector in Cartesian Coordinate System
norm_a = torch.sqrt(torch.square(position_a[:, :, 0:1]) + torch.square(position_a[:, :, 1:2]) + torch.square(position_a[:, :, 2:3]))+epsilon
norm_b = torch.sqrt(torch.square(position_b[:, :, 0:1]) + torch.square(position_b[:, :, 1:2]) + torch.square(position_b[:, :, 2:3]))+epsilon
x_true = position_a[:, :, 0:1]/norm_a
y_true = position_a[:, :, 1:2]/norm_a
z_true = position_a[:, :, 2:3]/norm_a
x_pred = position_b[:, :, 0:1]/norm_b
y_pred = position_b[:, :, 1:2]/norm_b
z_pred = position_b[:, :, 2:3]/norm_b
# Finally compute orthodromic distance
# great_circle_distance = np.arccos(x_true*x_pred+y_true*y_pred+z_true*z_pred)
# To keep the values in bound between -1 and 1
great_circle_distance = torch.acos(torch.clamp(x_true * x_pred + y_true * y_pred + z_true * z_pred, -1.0, 1.0))
return great_circle_distance.mean()
# This way we ensure that the network learns to predict the delta angle
def toPosition(values):
    """Compose an absolute position from a reference orientation and a
    predicted displacement.

    Args:
        values: two-element sequence ``[orientation, motion]`` of tensors
            with broadcast-compatible shapes.

    Returns:
        The element-wise sum ``orientation + motion`` (the network predicts
        a delta, which is applied to the previous position here).
    """
    orientation, motion = values[0], values[1]
    return orientation + motion
class TRACK_MODEL(nn.Module):
    """Encoder/decoder LSTM that fuses past positions with flattened saliency
    maps and autoregressively predicts the next ``H_WINDOW`` positions as
    residual ("delta") updates applied via ``toPosition``.
    """

    def __init__(self, M_WINDOW, H_WINDOW, NUM_TILES_HEIGHT, NUM_TILES_WIDTH,
                 input_size_pos=3, input_size_saliency=1, hidden_size=256):
        """Build encoder and decoder stacks.

        Args:
            M_WINDOW: number of past (input) timesteps.
            H_WINDOW: number of future timesteps to predict.
            NUM_TILES_HEIGHT, NUM_TILES_WIDTH: saliency-map grid size; each
                map is flattened to a vector of H*W values per timestep.
            input_size_pos: per-timestep position feature size (3-D vector).
            input_size_saliency: channels per saliency tile.
            hidden_size: hidden size shared by all LSTMs.
        """
        super().__init__()
        # Define encoders: one per modality, plus one that fuses both.
        self.pos_enc = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_enc = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH,
                               hidden_size=hidden_size, batch_first=True)
        self.fuse_1_enc = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)
        # Define decoders mirroring the encoder structure; they are seeded
        # with the final encoder hidden states in forward().
        self.pos_dec = nn.LSTM(input_size_pos, hidden_size=hidden_size, batch_first=True)
        self.sal_dec = nn.LSTM(input_size_saliency * NUM_TILES_HEIGHT * NUM_TILES_WIDTH,
                               hidden_size=hidden_size, batch_first=True)
        self.fuse_1_dec = nn.LSTM(hidden_size * 2, hidden_size=hidden_size, batch_first=True)
        # Projection head: hidden state -> 3-D position delta.
        self.fc_1 = nn.Linear(hidden_size, hidden_size)
        self.fc_layer_out = nn.Linear(hidden_size, 3)
        self.output_horizon = H_WINDOW
        self.input_window = M_WINDOW

    def forward(self, encoder_pos_inputs, encoder_sal_inputs,
                decoder_pos_inputs, decoder_sal_inputs, verbose=False):
        """Predict ``output_horizon`` future positions, one step at a time.

        Args:
            encoder_pos_inputs: (batch, M_WINDOW, 3) past positions.
            encoder_sal_inputs: saliency history; reshaped to
                (batch, M_WINDOW, -1), so any trailing layout that flattens to
                H*W per step is accepted.
            decoder_pos_inputs: (batch, 1, 3) seed position for step 0.
            decoder_sal_inputs: per-future-step saliency, indexed by timestep
                along dim 1.
            verbose: unused here (kept for caller compatibility).

        Returns:
            (batch, H_WINDOW, 3) predicted positions.

        NOTE(review): ``check_for_nans_or_infs`` is defined elsewhere in this
        file; it appears to be a debugging assertion/probe -- confirm it has
        no side effects on the tensors.
        """
        batch_size = encoder_pos_inputs.size(0)
        # Encode position inputs; final (h, c) seeds the position decoder.
        out_enc_pos, (h_n_pos, c_n_pos) = self.pos_enc(encoder_pos_inputs)
        check_for_nans_or_infs(out_enc_pos, 'out_enc_pos')
        # Flatten saliency input to one vector per timestep.
        flat_enc_sal_inputs = encoder_sal_inputs.view(batch_size, self.input_window, -1)
        out_enc_sal, (h_n_sal, c_n_sal) = self.sal_enc(flat_enc_sal_inputs)
        check_for_nans_or_infs(out_enc_sal, 'out_enc_sal')
        # Concatenate encoder outputs (saliency first, then position) and fuse.
        conc_out_enc = torch.cat([out_enc_sal, out_enc_pos], dim=-1)
        fuse_out_enc, (h_n_fuse, c_n_fuse) = self.fuse_1_enc(conc_out_enc)
        check_for_nans_or_infs(conc_out_enc, 'conc_out_enc')
        check_for_nans_or_infs(fuse_out_enc, 'fuse_out_enc')
        dec_input = decoder_pos_inputs
        all_pos_outputs = []
        # Autoregressive decoding: each step consumes the previous step's
        # predicted position and threads the (h, c) states forward.
        for t in range(self.output_horizon):
            # Decode position at this timestep.
            dec_pos_out, (h_n_pos, c_n_pos) = self.pos_dec(dec_input, (h_n_pos, c_n_pos))
            check_for_nans_or_infs(dec_pos_out, 'dec_pos_out')
            # Decode saliency at the current timestep (slice keeps seq dim = 1).
            selected_timestep_saliency = decoder_sal_inputs[:, t:t + 1].view(batch_size, 1, -1)
            dec_sal_out, (h_n_sal, c_n_sal) = self.sal_dec(selected_timestep_saliency, (h_n_sal, c_n_sal))
            check_for_nans_or_infs(dec_sal_out, 'dec_sal_out')
            # Fuse the two decoded streams (same saliency-first order as the encoder).
            dec_out = torch.cat((dec_sal_out, dec_pos_out), dim=-1)
            fuse_out_dec_1, (h_n_fuse, c_n_fuse) = self.fuse_1_dec(dec_out, (h_n_fuse, c_n_fuse))
            check_for_nans_or_infs(dec_out, 'dec_out')
            check_for_nans_or_infs(fuse_out_dec_1, 'fuse_out_dec_1')
            # FC head produces a position DELTA, not an absolute position.
            dec_fuse_out = self.fc_1(fuse_out_dec_1)
            outputs_delta = self.fc_layer_out(dec_fuse_out)
            # Apply the delta to the previous position.
            # NOTE(review): the sum is NOT re-normalized to a unit vector, so
            # autoregressive inputs drift off the unit sphere over the horizon.
            # MetricOrthLoss normalizes internally, but confirm downstream
            # consumers don't assume unit-norm outputs.
            decoder_pred = toPosition([dec_input, outputs_delta])
            check_for_nans_or_infs(dec_fuse_out, 'dec_fuse_out')
            check_for_nans_or_infs(outputs_delta, 'outputs_delta')
            check_for_nans_or_infs(decoder_pred, 'decoder_pred')
            all_pos_outputs.append(decoder_pred)
            # Feed the prediction back in as the next decoder input.
            dec_input = decoder_pred
        # Stack the per-step (batch, 1, 3) predictions along the time axis.
        decoder_outputs_pos = torch.cat(all_pos_outputs, dim=1)
        return decoder_outputs_pos
The problem is that at some point loss.backward() returns NaN values. I’ve checked all the inputs and confirmed that the pos_inputs are all 3-D unit vectors, while the sal_inputs are HxW tensors with values between -1 and 1. The model returns a normal loss value (not NaN) for the very batch where the backward step produces NaN. I’ve tried clipping the gradients, lowering the learning rate, and using a different initialization approach, but it still eventually produces NaN values within the first or second epoch. Setting torch.autograd.set_detect_anomaly(True) points at:
decoder_outputs_pos=torch.cat(all_pos_outputs,dim=1)
…
in backward
torch.autograd.backward(
…
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Function ‘CatBackward0’ returned nan values in its 10th output.
