CUDA Memory Overflow during LSTM + Attention


My LSTM module runs fine without attention, but with attention it runs into a CUDA out-of-memory error, and the memory used increases with every forward pass.
Code snippets are below. I pass the encoder hidden states to the decoder in a utility function.
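
The growth can be seen by logging torch.cuda.memory_allocated() after each pass. A minimal diagnostic sketch (not part of my model code; the helper name and tag argument are just for illustration):

import torch

def log_cuda_memory(tag=""):
    # Print the CUDA memory currently allocated by tensors, in MB
    if torch.cuda.is_available():
        mb = torch.cuda.memory_allocated() / (1024 ** 2)
        print(f"{tag}: {mb:.1f} MB allocated")

# e.g. called once per decoder step: log_cuda_memory(f"step {di}")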

LSTM Forward Pass

for di in range(target_length):
    # Get context vectors
    context_vectors = self.attention(lstm_hidden[0])  # Pass only the hidden state
    # Combine the attention context with the input
    lstm_input = self.attention_combine(torch.cat((context_vectors, lstm_input), 2))
    lstm_output, lstm_hidden = self.lstm(lstm_input, lstm_hidden)  # Feed to the LSTM
    lstm_input = embeddings[:, di, :].unsqueeze(1)  # Next input if teacher forcing = False
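
For reference, a sketch of how these layers could be defined so that the shapes line up (illustrative assumptions, not copied from my module; hidden size 1024 matches the comments in the attention code below):

import torch.nn as nn

hidden_size = 1024
# Assumed, illustrative definitions:
attention_combine = nn.Linear(2 * hidden_size, hidden_size)  # consumes torch.cat((context_vectors, lstm_input), 2)
lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)    # batch_first matches embeddings[:, di, :].unsqueeze(1)
attention_weights = nn.Linear(2 * hidden_size, 1)             # scores each (decoder state, encoder state) pair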

Attention Code

def attention(self, states):
        """Calculate attention weights.
        1. Take one step of the hidden states and calculate for all in the batch.
        2. Linearly transform to one weight.
        3. Append the attention weights.
        """
        states_temp = states.transpose(1,0) #decoder hidden states and transform shape to (31 * 1 * 1024) !!!
        encoder_state_num = self.encoder_hidden.shape[0] #Number of hidden states in encoder
        encoder_hidden_temp = self.encoder_hidden.unsqueeze(0) #encoder hidden shape (1*31*1024)
        states_temp = states_temp.repeat(1, encoder_state_num, 1)
        encoder_hidden_temp = encoder_hidden_temp.repeat(encoder_state_num, 1, 1)
        encoder_decoder_hidden_matrix = torch.cat((states_temp, encoder_hidden_temp), dim=2)  # Pair each decoder hidden state with each encoder hidden state

        attn_weights = torch.tanh(self.attention_weights(encoder_decoder_hidden_matrix)).squeeze()  # Attention weights shape (31 * 31): ith row indexes decoder hidden states, jth column indexes encoder hidden states

        attn_weights = self.mask_softmax(attn_weights, self.mask_vector)

        "Apply Attention Weights"
        #Efficient Implementation
        context_vectors = torch.matmul(attn_weights.permute(1,0).unsqueeze(2), self.encoder_hidden.unsqueeze(1)) #Apply attention weights to encoder_hidden vectors
        context_vectors = torch.sum(context_vectors, 0)
        context_vectors = context_vectors.unsqueeze(1) #context vector shape (31*1*1024) for all sentences - one context vector
        return context_vectors
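
To make the expected shapes concrete, here is a self-contained sketch of the same computation with dummy tensors (a plain softmax stands in for mask_softmax, and the Linear layer size is an assumption):

import torch
import torch.nn as nn

hidden = 1024
decoder_hidden = torch.randn(31, 1, hidden)         # states after transpose(1, 0): (31, 1, 1024)
encoder_hidden = torch.randn(31, hidden)            # self.encoder_hidden: (31, 1024)

dec = decoder_hidden.repeat(1, 31, 1)               # (31, 31, 1024)
enc = encoder_hidden.unsqueeze(0).repeat(31, 1, 1)  # (31, 31, 1024)
pair = torch.cat((dec, enc), dim=2)                 # (31, 31, 2048)

attention_weights = nn.Linear(2 * hidden, 1)        # assumed definition
w = torch.tanh(attention_weights(pair)).squeeze()   # (31, 31)
w = torch.softmax(w, dim=1)                         # plain softmax in place of mask_softmax

context = torch.matmul(w.permute(1, 0).unsqueeze(2), encoder_hidden.unsqueeze(1))  # (31, 31, 1024)
context = torch.sum(context, 0).unsqueeze(1)        # (31, 1, 1024)
print(context.shape)                                # torch.Size([31, 1, 1024])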