TypeError: isnan(): argument 'input' (position 1) must be Tensor, not PackedSequence
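
For reference, a minimal sketch of how the NaN check can be run on a packed batch: `torch.isnan` only accepts a `Tensor`, so the check goes on the `.data` attribute of the `PackedSequence`. The tensor shapes and the names `padded`, `lengths`, and `has_nan` are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch: 2 sequences, up to 4 time steps, 3 channels (illustrative values)
padded = torch.randn(2, 4, 3)
lengths = torch.tensor([4, 2])  # sorted in decreasing order, as pack_padded_sequence expects by default
packed = pack_padded_sequence(padded, lengths, batch_first=True)

# torch.isnan(packed) raises the TypeError above;
# the packed values live in packed.data, which is a regular Tensor.
has_nan = torch.isnan(packed.data).any().item()
```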

It turned out that I got NaN values when I used the standard BatchNorm1d. The time-series data has sequences of different lengths. I followed this discussion and wrote a custom class that applies BatchNorm only to the valid time steps, but I still get NaN values:

class MaskedNorm(nn.Module):
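    def __init__(self, num_features, mask_on=True):
        """BatchNorm1d applied only to the valid (unpadded) time steps.

        In forward(), y is the input tensor of shape [batch_size, time_length, n_channels];
        mask is of shape [batch_size, 1, time_length], with nonzero entries marking valid steps.
        """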
        
        # The process:
        #  1. Merge the batch and time axes using reshape
        #  2. Create a dummy time axis at the end with size 1.
        #  3. Select the valid time steps using the mask
        #  4. Apply BatchNorm1d to the valid time steps
        #  5. Scatter the resulting values to the corresponding positions
        #  6. Unmerge the batch and time axes
        super().__init__()
        self.norm = nn.BatchNorm1d(num_features=num_features)
        self.num_features = num_features
        self.mask_on = mask_on
        #
    def forward(self, y, mask=None):
        #
        self.sequence_length = y.shape[1]
        if self.training and self.mask_on:
            if mask is None:
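                # Infer each sequence length from channel 0 as (index of the last nonzero step) + 1.
                # This assumes zero padding at the end and at least one nonzero value per sequence.
                seq_len = [torch.max((y[i, :, 0] != 0).nonzero()).item() + 1 for i in range(y.shape[0])]
                # Put a 1 at the first padded position; the extra column handles full-length sequences.
                m = torch.zeros([y.shape[0], y.shape[1] + 1], dtype=torch.long, device=y.device)
                m[(torch.arange(y.shape[0]), seq_len)] = 1
                # After the cumulative sum, m is 1 on padded steps, so 1 - m is 1 on valid steps.
                m = m.cumsum(dim=1)[:, :-1]
                mask = 1 - m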
            reshaped = y.reshape([-1, self.num_features, 1])
            reshaped_mask = mask.reshape([-1, 1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape([-1, self.num_features, 1])
            batch_normed = self.norm(selected)
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([ -1, self.sequence_length, self.num_features])
        else:
            # In eval mode, or with mask_on=False, every time step (padding included) is normalized.
            reshaped = y.reshape([-1, self.num_features, 1])
            batch_normed = self.norm(reshaped)
            return batch_normed.reshape([-1, self.sequence_length, self.num_features])
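
For comparison, here is a minimal sketch of the same select / normalize / scatter steps using boolean indexing instead of `masked_select` and `masked_scatter`. It assumes the true lengths are known and passed in as a tensor rather than inferred from zero padding. The helper names `lengths_to_mask` and `masked_batch_norm` are hypothetical, not part of PyTorch or of the class above.

```python
import torch
import torch.nn as nn


def lengths_to_mask(lengths, max_len):
    """Return a bool mask of shape [batch_size, max_len]; True marks a valid step."""
    steps = torch.arange(max_len, device=lengths.device)
    return steps[None, :] < lengths[:, None]


def masked_batch_norm(norm, y, mask):
    """Apply `norm` (an nn.BatchNorm1d) to the valid time steps of y only.

    y:    [batch_size, time_length, n_channels]
    mask: [batch_size, time_length], bool
    Padded positions are returned unchanged.
    """
    out = y.clone()
    # y[mask] merges the batch and time axes and keeps valid rows: [n_valid, n_channels]
    out[mask] = norm(y[mask])
    return out


torch.manual_seed(0)
y = torch.randn(3, 5, 4)
lengths = torch.tensor([5, 3, 1])
mask = lengths_to_mask(lengths, y.shape[1])
y = y * mask[..., None]  # zero out the padded steps

norm = nn.BatchNorm1d(4)  # a fresh module is in training mode
out = masked_batch_norm(norm, y, mask)
```

With this layout the padded steps never enter the batch statistics, and the output can be checked directly with `torch.isnan(out).any()`.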

The input data going into BatchNorm doesn't contain any NaN values, so why do I still get NaN values in the output?
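
For context, this is roughly the kind of check I mean, written as a small sketch: it tests the input, the output, and the layer's parameters and running statistics for NaN or Inf. The helper `report_non_finite` and the toy tensor `x` are illustrative names only, not from my actual training code.

```python
import torch
import torch.nn as nn


def report_non_finite(name, t):
    """Return True (and print a message) if t contains NaN or Inf."""
    bad = not torch.isfinite(t).all().item()
    if bad:
        print(f"non-finite values in {name}")
    return bad


norm = nn.BatchNorm1d(4)
x = torch.randn(8, 4)  # toy input: 8 samples, 4 channels

checks = {"input": x, "output": norm(x)}
# Affine parameters and running statistics (the integer batch counter is skipped)
checks.update(
    {f"norm.{k}": v for k, v in norm.state_dict().items() if v.is_floating_point()}
)
results = {name: report_non_finite(name, t) for name, t in checks.items()}
```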