I have been training a modified version of a variational recurrent neural network (VRNN). When the loss gets close to zero, the model returns either inf or NaN values. I tried to fix the problem by reducing the learning rate of the optimizer, adding gradient clipping, and initializing the weights. However, I still get this error. I applied some of the methods suggested on the web to find where the NaN values occur. Here is part of my model and the nan_hook I used to pinpoint the place where the NaN error happens:

def nan_hook(self, inp, out):
    """
    Check for NaN inputs or outputs at each layer in the model.

    Meant to be registered as a forward hook, e.g.::

        # forward hook
        for submodule in model.modules():
            submodule.register_forward_hook(nan_hook)

    Args:
        self: the module the hook fired on; only its class name is used,
            to label the error message.
        inp: the module's forward input(s), either a tuple or a single value.
        out: the module's forward output(s), either a tuple or a single value.

    Raises:
        RuntimeError: if any tensor among the inputs or outputs contains NaN.
            Inputs are checked before outputs, so a "NaN output" error means
            no NaN was found in that layer's inputs.
    """
    # Conditional expressions instead of `cond and a or b`: that idiom falls
    # through to the `or` branch for an empty tuple (it is falsy) and would
    # wrap it as `[()]`.
    outputs = out if isinstance(out, tuple) else [out]
    inputs = inp if isinstance(inp, tuple) else [inp]

    def contains_nan(x):
        # torch.isnan only accepts tensors; skip None, scalars and any other
        # non-tensor values a module may receive or return.
        return isinstance(x, torch.Tensor) and bool(torch.isnan(x).any())

    layer = self.__class__.__name__

    for i, value in enumerate(inputs):
        if contains_nan(value):
            raise RuntimeError(f'Found NaN input at index: {i} in layer: {layer}')

    for i, value in enumerate(outputs):
        if contains_nan(value):
            raise RuntimeError(f'Found NaN output at index: {i} in layer: {layer}')

The main class

class VRNN_GMM(nn.Module):
    # Recurrent latent-variable model whose decoder heads (dec_mean, dec_logvar,
    # dec_pi) each emit y_dim * n_mixtures values; forward() accumulates
    # `-loss_pred + KLD` over time steps and returns the sum.
    # NOTE(review): this listing is truncated. Several brackets are unclosed, some
    # `if self.batch_norm:` bodies are empty, and names such as enc_mean,
    # layers_decoder_mean, layers_decoder_logvar, layers_decoder_pi,
    # reparametrization, kld_gauss and loglikelihood_gmm are used but never
    # defined in the visible code. The comments below cover only what is visible.
    def __init__(self,  u_dim, y_dim, h_dim, z_dim, n_layers, n_mixtures, device, batch_norm=False, bias=False):
        # u_dim / y_dim / z_dim: widths fed to the first Linear of phi_u / phi_y / phi_z.
        # h_dim: hidden width shared by the feature extractors and the GRUs.
        # n_layers: number of stacked GRU layers (also the leading dim of `h` in forward()).
        # n_mixtures: multiplier on y_dim for the decoder head output sizes.
        # device: where forward() allocates the initial hidden state.
        # batch_norm: toggles the optional BatchNorm1d layers.
        # bias: forwarded positionally to nn.GRU.
        super(VRNN_GMM, self).__init__()

        self.y_dim = y_dim
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.n_layers = n_layers
        self.n_mixtures = n_mixtures
        self.device = device
        self.batch_norm = batch_norm
        # feature-extracting transformations (phi_y, phi_u and phi_z)
        layers_phi_y = [nn.Linear(self.y_dim, self.h_dim)]
        if self.batch_norm:  # Add batch normalization when requested
        layers_phi_y = layers_phi_y + [
            nn.Linear(self.h_dim, self.h_dim),  
        self.phi_y = torch.nn.Sequential(*layers_phi_y)
        layers_phi_u = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:  
        layers_phi_u = layers_phi_u + [
            nn.Linear(self.h_dim, self.h_dim),  
        self.phi_u = torch.nn.Sequential(*layers_phi_u)
        layers_phi_z = [nn.Linear(self.z_dim, self.h_dim)]
        if self.batch_norm:
        layers_phi_z = layers_phi_z + [
            nn.Linear(self.h_dim, self.h_dim),
        self.phi_z = torch.nn.Sequential(*layers_phi_z)
        # encoder function (phi_enc) -> Inference
        # The encoder trunk takes 2 * h_dim features; forward() appears to build
        # them from phi_y_t and h[-1].
        encoder_layers=[nn.Linear(self.h_dim + self.h_dim, self.h_dim)]
        if self.batch_norm:
            nn.Linear(self.h_dim, self.h_dim),
        if self.batch_norm:
            encoder_layers= encoder_layers+[nn.BatchNorm1d(self.h_dim)]

        # NOTE(review): encoder_layers_mean and encoder_layers_logvar are built by
        # list concatenation from the same `encoder_layers`, so the trunk module
        # objects are shared between the two heads; only the final Linear differs.
        # The prior heads below follow the same pattern.
        encoder_layers_mean= encoder_layers+[
            nn.Linear(self.h_dim, self.z_dim)
        ###encoder log_var
        encoder_layers_logvar= encoder_layers+[
            nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = torch.nn.Sequential(*encoder_layers_logvar)
        # prior function (phi_prior) -> Prior
        # NOTE(review): the first prior Linear expects h_dim + h_dim features, but
        # forward() passes `h[-1]` alone as prior_input, and `h` is allocated with
        # last dimension h_dim. Confirm the intended input width.
        layers_prior = [nn.Linear(self.h_dim+ self.h_dim, self.h_dim)]
        if self.batch_norm:

        layers_prior_mean = layers_prior + [
            nn.Linear(self.h_dim, self.z_dim),
        self.prior_mean = torch.nn.Sequential(*layers_prior_mean)

        layers_prior_logvar = layers_prior + [
            nn.Linear(self.h_dim, self.z_dim),
        self.prior_logvar = torch.nn.Sequential(*layers_prior_logvar)

        # decoder function (phi_dec) -> Generation
        layers_decoder = [
            nn.Linear(self.h_dim + self.h_dim, self.h_dim)
        if self.batch_norm:
            nn.Linear(self.h_dim, self.h_dim)
        if self.batch_norm:
            nn.Linear(self.h_dim, self.y_dim * self.n_mixtures),
        self.dec_mean = nn.Sequential(*layers_decoder_mean)
            nn.Linear(self.h_dim, self.y_dim * self.n_mixtures),
        self.dec_logvar = nn.Sequential(*layers_decoder_logvar)
            nn.Linear(self.h_dim, self.y_dim * self.n_mixtures),
        if self.batch_norm:
            layers_decoder_pi.append(torch.nn.BatchNorm1d(self.y_dim * self.n_mixtures))
        # NOTE(review): Softmax(dim=1) runs on the flattened y_dim * n_mixtures
        # output, and forward() only afterwards views it as
        # (batch, y_dim, n_mixtures). Assuming a (batch, features) input, the
        # weights then sum to 1 across all output dimensions together rather than
        # per output dimension. Confirm this is intended.
        layers_decoder_pi = layers_decoder_pi + [nn.Softmax(dim=1)]

        self.dec_pi = nn.Sequential(*layers_decoder_pi)

        # recurrence function (f_theta) -> Recurrence
        # The 4th positional argument of nn.GRU is `bias`.
        # NOTE(review): self.rnn is declared with input_size h_dim + y_dim, while
        # forward() appears to feed it phi_u_t and phi_z_t combined (each h_dim
        # wide per their Sequential definitions). Confirm the sizes agree.
        self.rnn = nn.GRU(self.h_dim + self.y_dim, self.h_dim, self.n_layers, bias)
        # hidden_state_rnn is not referenced anywhere in the visible forward().
        self.hidden_state_rnn = nn.GRU(self.y_dim, self.h_dim, self.n_layers, bias)

    def forward(self, u, y):
        # Returns the accumulated loss (sum of -loss_pred + KLD over the loop).
        # `u` and `y` are indexed as [:, :, t], i.e. time is taken on the last axis
        # and the batch on the first (batch_size = y.size(0)).

        batch_size = y.size(0)
        #seq_len = y.shape[-1]#original
        # One entry per sample in `u`: count_nonzero(dim=-1) gives a per-row count
        # of non-zero entries, and the outer count_nonzero counts how many of those
        # rows are not entirely zero.
        # NOTE(review): presumably meant to derive sequence lengths from
        # zero-padding — confirm against the data layout.
        seq_len = torch.LongTensor([x.count_nonzero(dim=-1).count_nonzero().item() for x in u])
        # allocation
        loss = 0
        # initialization
        h = torch.zeros(self.n_layers, batch_size, self.h_dim, device=self.device)

        # for all time steps
        # NOTE(review): the outer loop runs once per sample, yet every inner step
        # processes the whole batch (y[:, :, t], u[:, :, t]). `h` is not reset
        # between outer iterations. As written, each time step is revisited
        # len(seq_len) times and its loss is added each time — confirm intent.
        for i in range(len(seq_len)):
            for t in range(seq_len[i]):
                # feature extraction: y_t
                phi_y_t = self.phi_y(y[:, :, t])
                # feature extraction: u_t
                phi_u_t = self.phi_u(u[:, :, t])

                # encoder: y_t, h_t -> z_t posterior 
                encoder_input =[phi_y_t, h[-1]], dim=1)
                enc_mean_t   = self.enc_mean(encoder_input)
                enc_logvar_t = self.enc_logvar(encoder_input)
                # NOTE(review): Softplus makes this value strictly positive. If
                # the helpers (not shown) treat it as a log-variance, the variance
                # can never drop below 1; if they treat it as a variance or std,
                # the name is misleading. Verify against reparametrization and
                # kld_gauss.
                enc_logvar_t = nn.Softplus()(enc_logvar_t)

                # prior: h_t -> z_t (for KLD loss)
                prior_input  = h[-1]
                prior_mean_t = self.prior_mean(prior_input)
                prior_logvar_t = self.prior_logvar(prior_input)
                prior_logvar_t = nn.Softplus()(prior_logvar_t)

                # sampling and reparameterization: get a new z_t
                #temp = tdist.Normal(enc_mean_t, enc_logvar_t.exp().sqrt())
                #z_t = tdist.Normal.rsample(temp)
                z_t=self.reparametrization(enc_mean_t, enc_logvar_t)
                # feature extraction: z_t
                phi_z_t = self.phi_z(z_t)

                # decoder: h_t, z_t -> y_t
      [phi_z_t, h[-1]], dim=1)
                # Each decoder head output is reshaped to (batch, y_dim, n_mixtures).
                dec_mean_t = self.dec_mean(decoder_input).view(batch_size, self.y_dim, self.n_mixtures)
                dec_logvar_t = self.dec_logvar(decoder_input).view(batch_size, self.y_dim, self.n_mixtures)
                dec_logvar_t = nn.Softplus()(dec_logvar_t)

                dec_pi_t = self.dec_pi(decoder_input).view(batch_size, self.y_dim, self.n_mixtures)
                # recurrence: u_t+1, z_t -> h_t+1
                _, h = self.rnn([phi_u_t, phi_z_t], 1).unsqueeze(0), h)

                # computing the loss
                KLD = self.kld_gauss(enc_mean_t, enc_logvar_t, prior_mean_t, prior_logvar_t)

                loss_pred = self.loglikelihood_gmm(y[:, :, t], dec_mean_t, dec_logvar_t, dec_pi_t)
                # loss_pred is subtracted and KLD is added.
                loss += - loss_pred + KLD

        return loss

Here is the training module:

def run_train(modelstate, loader_train, loader_valid, device, dataframe, path_general, file_name_general, nan_picker=True):
    # Training driver: defines validate() and train(epoch), both of which run
    # modelstate.model over (u, y) batches and return the summed loss divided by
    # total_points.
    # NOTE(review): this listing is truncated. The opening of the `train_options`
    # dict, the right-hand sides of `u =`, `y =` and `total_points +=`, the print
    # call around the format string, the backward/optimizer step and the
    # definition of `lr` are all missing from the visible code.
                   'lr_scheduler_nstart':10,# learning rate scheduler start epoch
                   'lr_scheduler_nepochs':5,# check learning rate after this many epochs
                   'lr_scheduler_factor':10,#adapt learning rate by
                   'min_lr':1e-6,#minimal learning rate
                   'init_lr':1e-4,#initial learning rate
    def validate(loader):
        # Evaluate the model on `loader` with gradients disabled and return the
        # accumulated loss divided by total_points.
        total_vloss = 0
        total_batches = 0
        total_points = 0
        with torch.no_grad():
            for i, (u, y) in enumerate(loader):
                u =
                y =
                vloss_ = modelstate.model(u, y)

                total_batches += u.size()[0]
                total_points +=
                total_vloss += vloss_.item()

        return total_vloss / total_points  # total_batches

    def train(epoch):
        # Run one pass over loader_train and return the accumulated loss divided
        # by total_points.
        # model in training mode
        # initialization
        total_loss = 0
        total_batches = 0
        total_points = 0

        for i, (u, y) in enumerate(loader_train):
            u =
            y =
            # set the optimizer
            # forward pass over model
            loss_ = modelstate.model(u, y)
            if nan_picker:
                #debug for nan values
                # NOTE(review): get_all_layers is not shown. It is called after
                # the forward pass, on every batch. If it registers forward hooks,
                # they only take effect from the next forward call and would be
                # registered again on each iteration — consider registering once
                # before the loop.
                get_all_layers(modelstate.model.m, nan_hook)

            # NN optimization

            # Per-submodule gradient clipping. clip_grad_value_ clamps each
            # gradient element to [-clip, clip]; clip_grad_norm_ rescales the
            # gradients so their combined 2-norm is at most 1.0.
            # NOTE(review): no backward() call is visible before these lines;
            # clipping only has an effect once gradients have been populated.
            torch.nn.utils.clip_grad_value_(modelstate.model.m.dec_mean.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_value_(modelstate.model.m.dec_logvar.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_value_(modelstate.model.m.dec_pi.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_value_(modelstate.model.m.phi_z.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_value_(modelstate.model.m.phi_u.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_value_(modelstate.model.m.phi_y.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_norm_(modelstate.model.m.prior_mean.parameters(), 1.0, norm_type=2)
            torch.nn.utils.clip_grad_norm_(modelstate.model.m.prior_logvar.parameters(), 1.0, norm_type=2)
            torch.nn.utils.clip_grad_value_(modelstate.model.m.rnn.parameters(), train_options['clip'])
            torch.nn.utils.clip_grad_norm_(modelstate.model.m.enc_mean.parameters(), 1.0, norm_type=2)
            torch.nn.utils.clip_grad_norm_(modelstate.model.m.enc_logvar.parameters(), 1.0, norm_type=2)

            total_batches += u.size()[0]
            total_points +=
            total_loss += loss_.item()

            # output to console
            if i % train_options['print_every'] == 0:
                    'Train Epoch: [{:5d}/{:5d}], Batch [{:6d}/{:6d} ({:3.0f}%)]\tLearning rate: {:.2e}\tLoss: {:.3f}'.format(
                        epoch, train_options['n_epochs'], (i + 1), len(loader_train),
                        100. * (i + 1) / len(loader_train), lr, total_loss / total_points))  # total_batches

        return total_loss / total_points

The error message is given as follows

   1081             if nan_picker:

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_5250/ in forward(self, u, y)
    948             y = self.normalizer_output.normalize(y)
--> 950         loss = self.m(u, y)
    952         return loss

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_5250/ in forward(self, u, y)
    654             for t in range(seq_len[i]):
    655                 # feature extraction: y_t
--> 656                 phi_y_t = self.phi_y(y[:, :, t])
    657                 # feature extraction: u_t
    658                 phi_u_t = self.phi_u(u[:, :, t])

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/ in forward(self, input)
    202     def forward(self, input):
    203         for module in self:
--> 204             input = module(input)
    205         return input
/tmp/ipykernel_5250/ in nan_hook(self, inp, out)
     47     for i, out in enumerate(outputs):
     48         if out is not None and contains_nan(out):
---> 49             raise RuntimeError(f'Found NaN output at index: {i} in layer: {layer}')
     51 def find_modules(nn_module, type):

RuntimeError: Found NaN output at index: 0 in layer: Linear

Even though I normalize the data and initialize the linear layers, I still get this error. Is there any other approach to avoid getting these NaNs during training?

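Check where exactly the NaN is created using your custom hooks.
I don’t know which layer “Linear” refers to, but it would be a good idea to check whether the input to this layer or the layer's parameters already contain invalid values. Note that the hook checks inputs before outputs and only tests for NaN, so a “NaN output” error means the input contained no NaN; it could still contain Inf, or the layer's weights could already be NaN/Inf after an earlier update. Based on this you should be able to narrow down the root cause further (e.g. maybe your input data is corrupt).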