CUDA Memory Overflow during LSTM + Attention

Hi,

My LSTM module runs fine without attention, but with attention it runs into a CUDA out-of-memory error: the memory used increases with every forward pass.
Code snippets are below. I pass the encoder hidden states to the decoder in a utility function (roughly sketched just below).
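
For context, the utility is roughly equivalent to the sketch below (the name pass_encoder_states and the exact signature are simplified; it just stores the tensors that attention() later reads as self.encoder_hidden and self.mask_vector):

def pass_encoder_states(decoder, encoder_hidden, mask_vector):
    # Simplified stand-in for the actual utility: store the encoder hidden
    # states and the padding mask on the decoder module, where attention()
    # reads them as self.encoder_hidden and self.mask_vector.
    decoder.encoder_hidden = encoder_hidden  # e.g. shape (31, 1024), one state per encoder step
    decoder.mask_vector = mask_vector        # mask consumed by mask_softmax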

LSTM Forward Pass

for di in range(target_length):
    # Get context vectors
    context_vectors = self.attention(lstm_hidden[0])  # Pass only the hidden state (not the cell state)
    # Combine attention with the input
    lstm_input = self.attention_combine(torch.cat((context_vectors, lstm_input), 2))
    lstm_output, lstm_hidden = self.lstm(lstm_input, lstm_hidden)  # Feed to the LSTM
    lstm_input = embeddings[:, di, :].unsqueeze(1)  # Next input if teacher forcing = False
    decoded_sentences.append(lstm_output)
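
The growth itself is easy to see with torch.cuda.memory_allocated(); a minimal, self-contained illustration of the measurement pattern (a dummy LSTM and dummy data, not my actual decoder) looks like this:

import torch
import torch.nn as nn

# Minimal illustration of how the per-pass growth is observed (dummy model/data,
# not the decoder above): run repeated forward passes, keep the outputs around
# as in decoded_sentences, and print the allocated memory after each pass.
lstm = nn.LSTM(input_size=1024, hidden_size=1024, batch_first=True).cuda()
x = torch.randn(31, 1, 1024, device="cuda")
outputs = []
for step in range(5):
    out, _ = lstm(x)
    outputs.append(out)  # outputs are collected, as in decoded_sentences
    print(step, torch.cuda.memory_allocated() // 1024**2, "MB allocated")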

Attention Code

def attention(self, states):
    '''
    Calculate attention weights:
    1. Take one step of the hidden states and calculate for all in the batch
    2. Linearly transform to one weight
    3. Append attention weights
    '''
    states_temp = states.transpose(1, 0)  # Decoder hidden states, transposed to shape (31, 1, 1024)
    encoder_state_num = self.encoder_hidden.shape[0]  # Number of hidden states in the encoder

    encoder_hidden_temp = self.encoder_hidden.unsqueeze(0)  # Encoder hidden shape (1, 31, 1024)
    states_temp = states_temp.repeat(1, encoder_state_num, 1)
    encoder_hidden_temp = encoder_hidden_temp.repeat(encoder_state_num, 1, 1)
    encoder_decoder_hidden_matrix = torch.cat((states_temp, encoder_hidden_temp), dim=2)

    # Attention weights shape (31, 31): the ith row indexes decoder hidden states,
    # the jth column indexes encoder hidden states
    attn_weights = torch.tanh(self.attention_weights(encoder_decoder_hidden_matrix)).squeeze()

    attn_weights = self.mask_softmax(attn_weights, self.mask_vector)

    # Apply attention weights (efficient implementation)
    context_vectors = torch.matmul(attn_weights.permute(1, 0).unsqueeze(2), self.encoder_hidden.unsqueeze(1))  # Apply attention weights to the encoder hidden vectors
    context_vectors = torch.sum(context_vectors, 0)
    context_vectors = context_vectors.unsqueeze(1)  # Context vector shape (31, 1, 1024): one context vector per sentence

    return context_vectors
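
For reference, here is a standalone shape check of the steps above with dummy tensors (31 sentences in the batch, 31 encoder states, hidden size 1024, as in the comments; a plain softmax stands in for mask_softmax):

import torch
import torch.nn as nn

# Standalone shape check of the attention steps above, using dummy tensors.
states = torch.randn(1, 31, 1024)       # decoder hidden state, like lstm_hidden[0]
encoder_hidden = torch.randn(31, 1024)  # stored encoder hidden states
attention_weights = nn.Linear(2048, 1)  # same role as self.attention_weights

states_temp = states.transpose(1, 0)                        # (31, 1, 1024)
encoder_hidden_temp = encoder_hidden.unsqueeze(0)           # (1, 31, 1024)
states_temp = states_temp.repeat(1, 31, 1)                  # (31, 31, 1024)
encoder_hidden_temp = encoder_hidden_temp.repeat(31, 1, 1)  # (31, 31, 1024)
pairs = torch.cat((states_temp, encoder_hidden_temp), dim=2)  # (31, 31, 2048)

attn_weights = torch.tanh(attention_weights(pairs)).squeeze()  # (31, 31)
attn_weights = torch.softmax(attn_weights, dim=1)  # plain softmax here; mask_softmax in my code

context = torch.matmul(attn_weights.permute(1, 0).unsqueeze(2), encoder_hidden.unsqueeze(1))  # (31, 31, 1024)
context = torch.sum(context, 0).unsqueeze(1)  # (31, 1, 1024)
print(attn_weights.shape, context.shape)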

Best,
Soumyadip