Infinite validation loss, even without shuffling data

I’m developing an LSTM autoencoder to encode text data. I’ve trained Flair's DocumentRNNEmbeddings to embed sequences of sentences and I’m using the saved model to generate the inputs.

Each batch of my training data has shape [12, 5, 2048]; i.e. each sample contains 5 sentences, each represented by a 2048-dimensional embedding.
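
For context, here’s roughly how the batches are built from the pre-computed sentence embeddings (a simplified sketch; the file name, the `DataLoader` setup and the variable names are illustrative, not my exact code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # pre-computed with the saved DocumentRNNEmbeddings model:
    # one row per document, 5 sentences each, 2048-dim per sentence
    embeddings = torch.load('document_embeddings.pt')   # shape [num_docs, 5, 2048]

    loader = DataLoader(TensorDataset(embeddings), batch_size = 12, shuffle = False)

    for (X_batch,) in loader:
        print(X_batch.shape)   # torch.Size([12, 5, 2048])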

PROBLEM

I’ve tried both BCEWithLogitsLoss and MSELoss with the Adam optimizer. In both cases my training loss looks reasonable (about 0.31 on average with BCE and 0.011 on average with MSE), but my validation loss almost always ends up as inf, or in rare cases an extremely large value such as 8.025489465e+32. I’m only trying to validate the autoencoder approach at this point, so my entire training set contains just 335 sentences (6 batches) and my test set 60 sentences (2 batches); I’m not sure whether the tiny dataset is causing the issue or whether something is wrong with the way I’m performing the forward pass.

Important parts of network initialization:

    self.lstm_enc = nn.LSTM(self.input_dim, self.hidden_dim,
                            self.num_layers, batch_first = True,
                            dropout = self.dropout_rate,
                            bidirectional = self.bidirectional)

    self.lstm_dec = nn.LSTM(self.hidden_dim * 2, self.output_dim // 2,
                            self.num_layers, batch_first = True,
                            dropout = self.dropout_rate,
                            bidirectional = self.bidirectional)

    self.hidden_rep = get_hidden_rep
    self.activation = nn.ELU()

    self.hidden_enc_weights = self._init_hidden_enc_weights()
    self.hidden_dec_weights = self._init_hidden_dec_weights()

    # self.loss_func = nn.BCEWithLogitsLoss()
    self.loss_func = nn.MSELoss(reduction = 'mean')
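
The hidden-state initializers referenced above roughly look like this (a simplified sketch, assuming Xavier-initialized (h_0, c_0) tensors and a `self.batch_size` attribute; the decoder version is analogous with `self.output_dim // 2` as the hidden size):

    def _init_hidden_enc_weights(self):
        # (h_0, c_0) for the encoder: (num_layers * num_directions, batch, hidden_dim)
        num_directions = 2 if self.bidirectional else 1
        h_0 = torch.empty(self.num_layers * num_directions, self.batch_size, self.hidden_dim)
        c_0 = torch.empty(self.num_layers * num_directions, self.batch_size, self.hidden_dim)
        nn.init.xavier_uniform_(h_0)
        nn.init.xavier_uniform_(c_0)
        return (h_0, c_0)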

Here’s what my forward pass looks like:

        # encode the batch of sentence embeddings
        encoder_output, self.hidden_enc_weights = self.lstm_enc(X_batch, self.hidden_enc_weights)
        self.hidden_enc_weights[0].clamp(min = 1e-1); self.hidden_enc_weights[1].clamp(min = 1e-1)
        # detach the encoder (h, c) states so they aren't back-propagated through across batches
        self.hidden_enc_weights[0].detach_(); self.hidden_enc_weights[1].detach_()

        encoder_output = self.activation(encoder_output)

        # optional linear layer between encoder and decoder
        if self.step_through_linear:
            encoder_output = self.step_linear(encoder_output)
            encoder_output = self.activation(encoder_output)

        # decoder pass over the encoded sequence
        decoder_output, self.hidden_dec_weights = self.lstm_dec(encoder_output, self.hidden_dec_weights)

        self.hidden_dec_weights[0].clamp(min = 0); self.hidden_dec_weights[1].clamp(min = 0)
        self.hidden_dec_weights[0].detach_(); self.hidden_dec_weights[1].detach_()

        # either return the encoded representation or the reconstruction loss
        if self.hidden_rep:
            return encoder_output
        else:
            decoder_output = self.activation(decoder_output)

            # return self.loss_func(decoder_output, torch.flip(X_batch, [0])) # BCEWithLogitsLoss
            return self.loss_func(X_batch, decoder_output) # MSELoss

I’ve also applied a few additional techniques, such as reversing the input sequence when computing the loss, Xavier weight initialization, and clamping the encoder and decoder hidden/cell states, but I don’t think these should affect the loss negatively.
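
For completeness, the overall training / validation loop looks roughly like this (a simplified sketch; `model`, `train_loader`, `val_loader`, `num_epochs` and the learning rate are illustrative names / values, not my exact code):

    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

    for epoch in range(num_epochs):
        model.train()
        for (X_batch,) in train_loader:
            optimizer.zero_grad()
            loss = model(X_batch)        # forward() returns the reconstruction loss
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_losses = [model(eval_batch).item() for (eval_batch,) in val_loader]
        print(f'epoch {epoch}: mean val loss = {sum(val_losses) / len(val_losses)}')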

WHAT I’VE TRIED SO FAR

I found similar issues, such as: Similar Issue Infinite Loss. Based on the suggested answers there, I added some checks to my code, e.g. checking for inf values in both my training and validation inputs:

    torch.isfinite(eval_batch).all().item()

This always returned True, i.e. there are no inf values in either set, so I’m not able to figure out what’s causing the issue. Any help is much appreciated.
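
Concretely, I run that check over every batch of both splits, roughly like this (sketch; the loader names are illustrative):

    def all_finite(loader):
        # True only if every batch contains exclusively finite values (no inf / nan)
        return all(torch.isfinite(batch).all().item() for (batch,) in loader)

    print(all_finite(train_loader), all_finite(val_loader))   # both print True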

P.S.: I recently moved to PyTorch from Keras.
TIA!

Please don’t tag specific people, as this might discourage others from posting an answer.
