TypeError: isnan(): argument 'input' (position 1) must be Tensor, not PackedSequence

So I have been trying to pinpoint the source of NaN values in my model. Here is part of my architecture:

class LSTMBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout=0, batchnorm=False, bias=False, num_layers=1, bidirectional=True):
        super().__init__()
        self._lstm = nn.LSTM(
                input_size=in_channels,
                hidden_size=out_channels,
                num_layers=num_layers,
                dropout=dropout,
                batch_first=True,
                bidirectional=bidirectional,
                bias=bias
        )
        self.n_dirs = 2 if bidirectional else 1
        self.fc_hid = nn.Linear(2*out_channels, out_channels)
        
        self.fc_out = nn.Linear(2*out_channels, out_channels)
        self.hidden_size = out_channels
        #initialize different layers
        init_weights(self._lstm)
        init_weights(self.fc_hid)
        init_weights(self.fc_out)
                        
    def forward(self, x):
        #(B,T,D )

        src_len = torch.LongTensor([torch.max((x[i,:, 0]!=0).nonzero()).item()+1 for i in range(x.shape[0])])
        
        
        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_len.cpu().numpy(), batch_first=True)
        
        packed_outputs, hidden_state = self._lstm(packed_x)

        hidden = hidden_state[0]

        hidden = hidden[-self.n_dirs:, :, :]#(2,B,H)

        # pad packed output sequence (B,T,2*H )
        outputs, lengths = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        outputs = self.fc_out(outputs) #(B,:src_len, H)    

        hidden_state = torch.tanh(self.fc_hid(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))
        

        return outputs, hidden_state

The error message that I received is:

--> 438         encoded, _ = self._encoder(encoder_input)
    439 
    440         # Aggregator: take the mean over all points

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1206             input = bw_hook.setup_input_hook(input)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/tmp/ipykernel_4888/1110288317.py in forward(self, x)
    234         packed_x = nn.utils.rnn.pack_padded_sequence(x, src_len.cpu().numpy(), batch_first=True)
    235 
--> 236         packed_outputs, hidden_state = self._lstm(packed_x)
    237 
    238         hidden = hidden_state[0]

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1211                 hook_result = hook(self, input, result)
   1212                 if hook_result is not None:
   1213                     result = hook_result

/tmp/ipykernel_4888/1110288317.py in nan_hook(self, inp, out)
     59 
     60     for i, inp in enumerate(inputs):
---> 61         if inp is not None and contains_nan(inp):
     62             raise RuntimeError(f'Found NaN input at index: {i} in layer: {layer}')
     63 

/tmp/ipykernel_4888/1110288317.py in <lambda>(x)
     55     inputs = isinstance(inp, tuple) and inp or [inp]
     56 
---> 57     contains_nan = lambda x: torch.isnan(x).any()
     58     layer = self.__class__.__name__
     59 

TypeError: isnan(): argument 'input' (position 1) must be Tensor, not PackedSequence

The code runs on the validation dataset with a batch size of one, but when I train the model with a batch size of 40, I get the above error. I would appreciate it if someone could offer a way to solve this error.

You could access the .data attribute, which will return the tensor stored internally in the PackedSequence, and it should work:

x = torch.cat([torch.randn(1, 10) for _ in range(5)], dim=0)
x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths=[10]*5, batch_first=True)

print(torch.isnan(x))
# TypeError: isnan(): argument 'input' (position 1) must be Tensor, not PackedSequence

print(torch.isnan(x.data).any())
# tensor(False)
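
If you want to keep the nan_hook in place, its helper could simply unwrap a PackedSequence before checking. A minimal sketch, assuming the contains_nan lambda from your traceback is the piece you want to adapt:

import torch
from torch.nn.utils.rnn import PackedSequence

def contains_nan(x):
    # unwrap a PackedSequence and check its flattened .data tensor instead
    if isinstance(x, PackedSequence):
        x = x.data
    return torch.isnan(x).any()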

So it turned out that I got NaN values when I used the standard BatchNorm1d. The time-series data has sequences of different lengths. I followed this discussion and wrote a custom class to implement BatchNorm with a mask, but I still get NaN values:

class MaskedNorm(nn.Module):
    def __init__(self, num_features, mask_on=True):
        """y is the input tensor of shape [batch_size,  time_length,n_channels]
            mask is of shape [batch_size, 1, time_length]
        """
        
        # The process:
        #  1. Merge the batch and time axes using reshape
        #  2. Create a dummy time axis at the end with size 1.
        #  3. Select the valid time steps using the mask
        #  4. Apply BatchNorm1d to the valid time steps
        #  5. Scatter the resulting values to the corresponding positions
        #  6. Unmerge the batch and time axes
        super().__init__()
        self.norm = nn.BatchNorm1d(num_features=num_features)
        self.num_features = num_features
        self.mask_on = mask_on
        #
    def forward(self, y, mask=None):
        #
        self.sequence_length = y.shape[1]
        if self.training and self.mask_on:
            if mask is None:
                seq_len = [torch.max((y[i,:, 0]!=0).nonzero()).item()+1 for i in range(y.shape[0])]
                m  = torch.zeros([y.shape[0],y.shape[1]+1], dtype=torch.bool).to(y.device)
                m[(torch.arange(y.shape[0]), seq_len)] = 1
                m  = m.cumsum(dim=1)[:, :-1] 
                mask = (1-m)
            reshaped = y.reshape([-1, self.num_features, 1])
            reshaped_mask = mask.reshape([-1, 1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
            batch_normed = self.norm(selected)
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([ -1, self.sequence_length, self.num_features])
        else:
            reshaped = y.reshape([-1, self.num_features, 1])
            batched_normed = self.norm(reshaped)
            return batched_normed.reshape([ -1, self.sequence_length, self.num_features])

The input data before the BatchNorm doesn't have any NaN values, so I am wondering why I still get NaN values.

Check the intermediate values to narrow down which operation is creating the NaN values.
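
For example, you could add a few prints right before the self.norm(selected) call (just a debugging sketch; selected and self.norm are the names from your posted MaskedNorm.forward):

# inspect what goes into the BatchNorm1d call
print("selected finite:", torch.isfinite(selected).all().item())
print("num selected steps:", selected.shape[0])
# per-feature std over the valid (masked-in) time steps;
# a value of (almost) zero means the batch statistics are degenerate
print("per-feature std:", selected.squeeze(-1).std(dim=0))
# running statistics of the BatchNorm layer itself
print("running_mean finite:", torch.isfinite(self.norm.running_mean).all().item())
print("running_var finite:", torch.isfinite(self.norm.running_var).all().item())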

Here is the error:

-> 1205         context_x = self.norm_u(inputs)
   1207         assert not torch.isnan(context_x).any()

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1206             input = bw_hook.setup_input_hook(input)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/tmp/ipykernel_5278/2614583974.py in forward(self, y, mask)
    146             reshaped_mask = mask.reshape([-1, 1, 1]) > 0
    147             selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
--> 148             batch_normed = self.norm(selected)
    149             scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
    150             return scattered.reshape([ -1, self.sequence_length, self.num_features])

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1211                 hook_result = hook(self, input, result)
   1212                 if hook_result is not None:
   1213                     result = hook_result

/tmp/ipykernel_5278/2614583974.py in nan_hook(self, inp, out)
     60     for i, inp in enumerate(inputs):
     61         if inp is not None and contains_nan(inp):
---> 62             raise RuntimeError(f'Found NaN input at index: {i} in layer: {layer}')
     63 
     64     for i, out in enumerate(outputs):

RuntimeError: Found NaN input at index: 0 in layer: BatchNorm1d

One possible reason I'm getting NaN values could be a standard deviation equal to zero, which might be caused by the masked data. I am wondering whether the experts here have any suggestions on how to solve this problem? Many thanks.

That sounds plausible and you would need to check it.
Your current debugging approach already points towards the self.norm layer, so try to recompute the normalization and check where the invalid outputs are created.
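
Roughly something like this, right before the self.norm(selected) call (only a sketch; the names are taken from your posted code, and the batch statistics are recomputed the way BatchNorm1d uses them in training mode):

x = selected.squeeze(-1)                  # (num_valid_steps, num_features)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)        # biased variance used for the normalization
manual = (x - mean) / torch.sqrt(var + self.norm.eps)
print("min batch var:", var.min().item())
print("manual norm finite:", torch.isfinite(manual).all().item())

If var.min() really is (close to) zero, that would support your suspicion about the standard deviation; if the manual result is finite while self.norm(selected) is not, the problem more likely sits in the layer's affine parameters or running statistics rather than in the normalization itself.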

I wrote a LayerNorm class and combined it with the above MaskedNorm class:

class LayerNorm_(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm_, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        if x.shape[0]==1:
            #input (batch, sequence, feature)
            mean = x.mean(1)
            std  = x.std(1)
            y = (x - mean) / (std + self.eps)
            if self.affine:
                y = self.gamma * y + self.beta
        else:
            #input (batch*not masked sequence, feature)
            shape = [-1] + [1] * (x.dim() - 1)
            mean = x.mean(0, keepdim=True).view(*shape)
            std = x.std(0, keepdim=True).view(*shape)
            assert not torch.isnan(mean).any()
            assert not torch.isinf(mean).any()
            assert not torch.isnan(std).any()
            assert not torch.isinf(std).any()
            y = (x - mean.squeeze(-1)) / (std.squeeze(-1) + self.eps)
            assert not torch.isnan(y).any()
            if self.affine:
                shape = [1, -1] + [1] * (x.dim() - 2)
                y = self.gamma.view(*shape) * y + self.beta.view(*shape)
                assert not torch.isnan(y).any()
        return y

class MaskedNorm(nn.Module):
    #BatchNorm1d with mask
    def __init__(self, num_features, mask_on=True, norm_name=None, affine=True):
        super().__init__()
        if norm_name == "BN":
            self.norm = nn.BatchNorm1d(num_features)
        else:
            self.norm = LayerNorm_(num_features, affine = affine, eps=1e-04)
        self.num_features = num_features
        self.mask_on = mask_on
        self._norm_name = norm_name
        #
    def forward(self, y, mask=None):
        # input shapes: BatchNorm1d expects (N, C, L);
        # LayerNorm expects (batch, sequence_length, embedding_dim)
        self.sequence_length = y.shape[1]
        if self.training and self.mask_on:
            if mask is None:
                seq_len = [torch.max((y[i,:, 0]!=0).nonzero()).item()+1 for i in range(y.shape[0])]
                m  = torch.zeros([y.shape[0],y.shape[1]+1], dtype=torch.bool).to(y.device)
                m[(torch.arange(y.shape[0]), seq_len)] = 1
                m  = m.cumsum(dim=1)[:, :-1]
                mask = (1-m)
            if self._norm_name == "BN":
                reshaped = y.reshape([-1, self.num_features, 1])
                reshaped_mask = mask.reshape([-1, 1, 1]) > 0
                selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
                batch_normed = self.norm(selected)
                scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
                return scattered.reshape([ -1, self.sequence_length, self.num_features])
            else:
                reshaped = y.reshape([-1, self.num_features, 1])
                reshaped_mask = mask.reshape([-1, 1, 1]) > 0
                selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
                assert not torch.isnan(selected).any()
                batch_normed = self.norm(selected)
                assert not torch.isnan(batch_normed).any()
                scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
                return scattered.reshape([ -1, self.sequence_length, self.num_features])
        else:
            if self._norm_name == "BN":
                reshaped = y.reshape([-1, self.num_features, 1])
                batched_normed = self.norm(reshaped)
                return batched_normed.reshape([ -1, self.sequence_length, self.num_features])
            else:
                return self.norm(y)

I included some assert lines to show where NaNs happen. However, even though none of them raise an assertion error, I still get the same error, now for the LayerNorm:

Train Epoch: [    0/  400], Batch [     1/    67 (  1%)]	Learning rate: 5.00e-05	Loss: nan
---------------------------------------------------------------------------

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1206             input = bw_hook.setup_input_hook(input)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/tmp/ipykernel_4798/2638772351.py in forward(self, u, y)
   1264         #mistmach between input???
   1265         if self._MaskedNorm:
-> 1266             normed_u = self.norm_u(reshaped_u)
   1267             normed_y = self.norm_y(reshaped_y)
   1268             assert not torch.isnan(normed_u).any()

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1206             input = bw_hook.setup_input_hook(input)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/tmp/ipykernel_4798/2638772351.py in forward(self, y, mask)
    198                 selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
    199                 assert not torch.isnan(selected).any()
--> 200                 batch_normed = self.norm(selected)
    201                 assert not torch.isnan(batch_normed).any()
    202                 scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:
   1210             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1211                 hook_result = hook(self, input, result)
   1212                 if hook_result is not None:
   1213                     result = hook_result

/tmp/ipykernel_4798/2638772351.py in nan_hook(self, inp, out)
     60     for i, inp in enumerate(inputs):
     61         if inp is not None and contains_nan(inp):
---> 62             raise RuntimeError(f'Found NaN input at index: {i} in layer: {layer}')
     63 
     64     for i, out in enumerate(outputs):

RuntimeError: Found NaN input at index: 0 in layer: LayerNorm_

What are possible causes for the NaN output?

I have been dealing with this error for the past few days and it is really frustrating. I looked up similar errors and circumstances, and in previous posts on this forum @ptrblck suggested using these lines to find the root of the problem:

for name, param in modelstate.model.m.named_parameters():
    if torch.is_tensor(param.grad):
        print(name, torch.isfinite(param.grad).all())

Here are the results for my model:

norm_u.norm.gamma tensor(False, device='cuda:0')
norm_u.norm.beta tensor(False, device='cuda:0')
norm_y.norm.gamma tensor(False, device='cuda:0')
norm_y.norm.beta tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_ih_l0 tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_hh_l0 tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_ih_l0_reverse tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_hh_l0_reverse tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_ih_l1 tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_hh_l1 tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_ih_l1_reverse tensor(False, device='cuda:0')
_latent_encoder._encoder._lstm.weight_hh_l1_reverse tensor(False, device='cuda:0')
_latent_encoder._encoder.fc_out.weight tensor(False, device='cuda:0')
_latent_encoder._encoder.fc_out.bias tensor(False, device='cuda:0')
_latent_encoder._encoder.norm.weight tensor(False, device='cuda:0')
_latent_encoder._encoder.norm.bias tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_ih_l0 tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_hh_l0 tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_ih_l0_reverse tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_hh_l0_reverse tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_ih_l1 tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_hh_l1 tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_ih_l1_reverse tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm._lstm.weight_hh_l1_reverse tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm.fc_out.weight tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm.fc_out.bias tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm.norm.weight tensor(False, device='cuda:0')
_latent_encoder._self_attention._lstm.norm.bias tensor(False, device='cuda:0')
_latent_encoder._self_attention._W.in_proj_weight tensor(False, device='cuda:0')
_latent_encoder._self_attention._W.in_proj_bias tensor(False, device='cuda:0')
_latent_encoder._self_attention._W.out_proj.weight tensor(False, device='cuda:0')
_latent_encoder._self_attention._W.out_proj.bias tensor(False, device='cuda:0')
_latent_encoder._penultimate_layer.weight tensor(False, device='cuda:0')
_latent_encoder._penultimate_layer.bias tensor(False, device='cuda:0')
_latent_encoder._mean.weight tensor(False, device='cuda:0')
_latent_encoder._mean.bias tensor(False, device='cuda:0')
_latent_encoder._log_var.weight tensor(False, device='cuda:0')
_latent_encoder._log_var.bias tensor(False, device='cuda:0')
phi_y.0.weight tensor(False, device='cuda:0')
phi_y.0.bias tensor(False, device='cuda:0')
phi_y.2.weight tensor(False, device='cuda:0')
phi_y.2.bias tensor(False, device='cuda:0')
phi_z.0.weight tensor(False, device='cuda:0')
phi_z.0.bias tensor(False, device='cuda:0')
phi_z.2.weight tensor(False, device='cuda:0')
phi_z.2.bias tensor(False, device='cuda:0')
enc.0.weight tensor(False, device='cuda:0')
enc.0.bias tensor(False, device='cuda:0')
enc.2.weight tensor(False, device='cuda:0')
enc.2.bias tensor(False, device='cuda:0')
enc_mean.weight tensor(False, device='cuda:0')
enc_mean.bias tensor(False, device='cuda:0')
enc_logvar.weight tensor(False, device='cuda:0')
enc_logvar.bias tensor(False, device='cuda:0')
prior.0.weight tensor(False, device='cuda:0')
prior.0.bias tensor(False, device='cuda:0')
prior_mean.weight tensor(False, device='cuda:0')
prior_mean.bias tensor(False, device='cuda:0')
prior_logvar.weight tensor(False, device='cuda:0')
prior_logvar.bias tensor(False, device='cuda:0')
dec.0.weight tensor(False, device='cuda:0')
dec.0.bias tensor(False, device='cuda:0')
dec.2.weight tensor(False, device='cuda:0')
dec.2.bias tensor(False, device='cuda:0')
dec_mean.weight tensor(False, device='cuda:0')
dec_mean.bias tensor(False, device='cuda:0')
dec_scale_diag.0.weight tensor(False, device='cuda:0')
dec_scale_diag.0.bias tensor(False, device='cuda:0')
dec_scale_tril.0.weight tensor(False, device='cuda:0')
dec_scale_tril.0.bias tensor(False, device='cuda:0')
dec_pi.0.weight tensor(False, device='cuda:0')
dec_pi.0.bias tensor(False, device='cuda:0')

I examined the input data to my model and there aren't any NaN or -inf/+inf values inside the dataset. I also plotted the results of LayerNorm_ and they look perfectly fine. Any further suggestions anyone can provide would be greatly appreciated.

I'm a bit confused, as I understood that you had already narrowed the issue down to a single layer in your previous steps?
If not, try to do so and isolate the layer which creates the invalid outputs by checking the intermediate activations. Once this is done, check what exactly the layer applies to the (valid) input to create the (invalid) output. One of these operations must cause the issue.
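
One way to do that, sketched roughly (it reuses the idea of your existing nan_hook; nan_dump_hook and the nan_batch.pt file name are just examples), is to save the exact batch that triggers the NaN and then replay the suspect layer on it in isolation:

import torch
from torch.nn.utils.rnn import PackedSequence

def nan_dump_hook(module, inputs, output):
    outs = output if isinstance(output, tuple) else (output,)
    for i, o in enumerate(outs):
        if o is None:
            continue
        t = o.data if isinstance(o, PackedSequence) else o
        if torch.isnan(t).any():
            # keep the offending tensors so the layer can be re-run offline
            torch.save({'inputs': inputs, 'output': output}, 'nan_batch.pt')
            raise RuntimeError(f'NaN output at index {i} in {module.__class__.__name__}')

Loading nan_batch.pt afterwards and calling the isolated LayerNorm_ (or BatchNorm1d) module on the saved inputs lets you step through its forward line by line and see which operation turns the valid input into the invalid output.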