Hi,
My LSTM module runs fine without attention, but with attention it runs into a CUDA out-of-memory error. The memory used increases with every forward pass.
The code snippets are below. I pass the encoder hidden states to the decoder in a utility function.
LSTM Forward Pass
for di in range(target_length):
    # Attend using only the hidden state h (lstm_hidden[0]); the cell state
    # lstm_hidden[1] is not passed to the attention.
    context_vectors = self.attention(lstm_hidden[0])
    # Fuse the context vectors with the current input along dim 2 (the
    # feature axis) before feeding the LSTM.
    lstm_input = self.attention_combine(torch.cat((context_vectors, lstm_input), 2))
    lstm_output, lstm_hidden = self.lstm(lstm_input, lstm_hidden)
    # Teacher forcing (ON): the next input is the ground-truth embedding at
    # step di, not the model's own prediction.
    # Assumes `embeddings` is (batch, target_length, dim) -- TODO confirm.
    lstm_input = embeddings[:, di, :].unsqueeze(1)
    # NOTE(review): every appended step output keeps its autograd history
    # alive until backward. If this list (or a loss built from it) outlives
    # the training step in the caller, memory will grow across forward
    # passes -- worth checking.
    decoded_sentences.append(lstm_output)
Attention Code
def attention(self, states):
    """Compute concat-style attention context vectors for the decoder.

    For every decoder hidden state and every row of ``self.encoder_hidden``:

    1. Concatenate the two vectors.
    2. Score the pair with ``self.attention_weights`` followed by ``tanh``.
    3. Normalise the scores with ``self.mask_softmax(scores, self.mask_vector)``.
    4. Take the weighted sum of the encoder rows using the normalised weights.

    Args:
        states: decoder hidden state (``lstm_hidden[0]`` at the call site).
            Shape is presumably (1, B, H) -- assumes a single-layer,
            unidirectional LSTM; TODO confirm.

    Returns:
        Context vectors of shape (B, 1, H), one per decoder state.
    """
    # NOTE(review): ``self.encoder_hidden.shape[0]`` matched the decoder batch
    # size (31) in the original comments. If ``encoder_hidden`` holds one
    # final state per sentence, each sentence attends over *other sentences'*
    # encoder states rather than over its own time steps -- confirm intent.
    decoder_states = states.transpose(1, 0)      # (B, 1, H)
    encoder_hidden = self.encoder_hidden         # (S, H)
    num_dec = decoder_states.shape[0]
    num_enc = encoder_hidden.shape[0]

    # Broadcast views (no copies) pairing every decoder state with every
    # encoder row. Unlike the previous repeat(), the grid is sized by the
    # decoder batch, so B and S no longer have to be equal.
    dec_grid = decoder_states.expand(num_dec, num_enc, decoder_states.shape[-1])
    enc_grid = encoder_hidden.unsqueeze(0).expand(num_dec, num_enc, encoder_hidden.shape[-1])
    pair_features = torch.cat((dec_grid, enc_grid), dim=2)   # (B, S, H_dec + H_enc)

    # Squeeze only the trailing score axis; a bare squeeze() would also drop
    # B or S when either is 1 (e.g. batch size 1) and break the matmul below.
    attn_weights = torch.tanh(self.attention_weights(pair_features)).squeeze(-1)  # (B, S)
    attn_weights = self.mask_softmax(attn_weights, self.mask_vector)

    # Apply the attention weights: a plain (B, S) @ (S, H) product. This is
    # equal to the former sum of per-encoder outer products, but avoids
    # materialising an (S, B, H) intermediate.
    context_vectors = torch.matmul(attn_weights, encoder_hidden)   # (B, H)
    return context_vectors.unsqueeze(1)                            # (B, 1, H)
Best,
Soumyadip