Clarification on Backpropagation for LSTMs

I’m trying to train an image captioning model using 2 LSTMs, but I’m confused about how to properly calculate the loss and make the network learn. The tutorials I see for LSTMs (not necessarily for image captioning) seem to pass in a whole sentence and immediately get back the loss for the whole sentence (as in POS tagging).

For clarification #1:
I was under the impression that you have to pass each word to the LSTM timestep-by-timestep (which means there’s an inner for loop that iterates through each word), and each time, you have to pass in the previous word and the previous hidden state to the LSTM. Is this correct?

For clarification #2:
If #1 is correct, then, at each timestep, you get an output, and you incur a loss (for example, cross-entropy). To make the LSTM learn, do I (1) immediately call loss.backward() but run optimizer.step() after the whole sequence, or (2) accumulate the loss for all the timesteps, and only call accumulated_loss.backward() after the sequence (followed of course by optimizer.step())?

Would really appreciate the help. I’ve trained my network using method (2), but the loss doesn’t go below a certain point. I also tried running it on a small data set to make it overfit, and preferably memorize the training data, just to see if it learns sequences, but the outputs don’t show any sentence structure. The most probable words are almost always the same for all of the timesteps.

Thanks in advance!

Re #1: LSTM takes the whole sequence and performs each time step in the background. However, nothing is stopping you from giving the LSTM just one word at a time and passing the returned hidden state back in yourself. It depends on your task and how you want to implement it.
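To make the two options concrete, here is a minimal sketch that runs the same nn.LSTM once on a whole sequence and once step by step. The sizes and the random input are made up for illustration; in a real model the input would come from an embedding layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
embed_dim, hidden_dim, seq_len = 8, 16, 5
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

# One sequence of 5 already-embedded "words": (batch, seq_len, embed_dim).
x = torch.randn(1, seq_len, embed_dim)

# Option A: hand over the whole sequence; the time loop runs inside nn.LSTM.
out_full, (h_full, c_full) = lstm(x)

# Option B: feed one time step at a time and carry the state forward yourself.
state = None  # None means zero-initialised hidden and cell state
step_outputs = []
for t in range(seq_len):
    out_t, state = lstm(x[:, t:t + 1, :], state)
    step_outputs.append(out_t)
out_steps = torch.cat(step_outputs, dim=1)

# Both routes should give the same outputs up to floating-point noise.
same = torch.allclose(out_full, out_steps, atol=1e-5)
```

Option A is usually enough when the full input sequence is known up front (as in tagging); Option B is what a decoder needs when each input depends on the previous output.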

Re #2: I don’t think (1) is the way to go, since each backward() call would backpropagate again through the same earlier time steps. (2) is the common approach for encoder-decoder models (e.g., autoencoders, machine translation), where the decoder generates words step by step. Sequence tagging is actually a bit more straightforward. If you look at your linked POS tagging tutorial, there is no loop needed. tag_scores = model(sentence_in) contains the output of all time steps, so you can calculate the loss in one go with loss = loss_function(tag_scores, targets).
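As a rough sketch of the tagging case (not the tutorial's exact code): the Tagger class, its sizes, and the word/tag indices below are made up, and I use CrossEntropyLoss on raw scores rather than whatever loss the tutorial uses.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration.
vocab_size, tagset_size, embed_dim, hidden_dim = 20, 3, 6, 6

class Tagger(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim)
        self.to_tags = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        # sentence: (seq_len,) word indices -> (seq_len, 1, embed_dim)
        embedded = self.embed(sentence).unsqueeze(1)
        lstm_out, _ = self.lstm(embedded)         # (seq_len, 1, hidden_dim)
        return self.to_tags(lstm_out.squeeze(1))  # (seq_len, tagset_size)

model = Tagger()
loss_function = nn.CrossEntropyLoss()

sentence_in = torch.tensor([4, 7, 1, 12])  # made-up word indices
targets = torch.tensor([0, 1, 2, 0])       # one made-up tag per word

tag_scores = model(sentence_in)            # scores for all time steps at once
loss = loss_function(tag_scores, targets)  # mean loss over the time steps
loss.backward()
```

No explicit loop is needed here because every input word is known in advance, so the LSTM can consume the whole sentence in one call.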

Given that you want to do image captioning (i.e., generating text), you need a loop and should use approach (2): accumulate the loss over the time steps and call loss.backward() only once. Have a look at the Sequence-to-Sequence tutorial, particularly at the decoder part. Your decoder will look roughly the same. The only difference is that the hidden state won’t come from a text encoder but from some image encoder.
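Here is a minimal sketch of approach (2) for a single caption. Everything named here (CaptionDecoder, the sizes, the SOS/EOS indices, the random image features, feeding the ground-truth word back in as the next input) is an assumption for illustration, not the tutorial's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embed_dim, hidden_dim, feat_dim = 12, 8, 16, 10  # hypothetical sizes
SOS, EOS = 0, 1  # hypothetical special-token indices

class CaptionDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.init_h = nn.Linear(feat_dim, hidden_dim)  # image features -> h0
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def init_state(self, image_feats):
        h0 = torch.tanh(self.init_h(image_feats)).unsqueeze(0)  # (1, batch, hidden)
        return h0, torch.zeros_like(h0)

    def step(self, prev_word, state):
        emb = self.embed(prev_word).unsqueeze(1)   # (batch, 1, embed)
        output, state = self.lstm(emb, state)
        return self.out(output.squeeze(1)), state  # (batch, vocab) scores

decoder = CaptionDecoder()
optimizer = torch.optim.SGD(decoder.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

image_feats = torch.randn(1, feat_dim)    # stand-in for an image encoder's output
caption = torch.tensor([[5, 7, 3, EOS]])  # made-up target words, ending in EOS

optimizer.zero_grad()
state = decoder.init_state(image_feats)
prev_word = torch.tensor([SOS])
loss = 0.0
for t in range(caption.size(1)):
    logits, state = decoder.step(prev_word, state)
    loss = loss + criterion(logits, caption[:, t])  # accumulate, no backward yet
    prev_word = caption[:, t]                       # feed the ground-truth word next
loss.backward()   # one backward pass through all time steps
optimizer.step()
```

Dividing the accumulated loss by the caption length before backward() is a common variant; it only rescales the gradient.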

I hope that helps.

Thanks so much for taking the time! This was helpful. I just have two more questions:

  1. The LSTM that you were talking about that performs each time step in the background, specifically, is the nn.LSTM, not nn.LSTMCell, right? This might have been one of my confusions.

  2. If training the LSTM by batch, I pad the captions to have the same length, and the final loss is the mean of the loss of all the targets (excluding padded ones). Should I be doing anything else?

Re #1: Yes, I was talking solely about nn.LSTM. Conceptually, it wraps an LSTMCell and adds support for multiple layers, dropout, bidirectionality, etc. An LSTMCell only ever takes a single time step (one word) as input, whereas LSTM is a full layer that takes whole sequences as input and returns the outputs of all time steps. It’s just that no one is stopping you from giving it sequences of length 1. An LSTM with num_layers=1, bidirectional=False and dropout=0.0 that takes one word at a time should be more or less the same as an LSTMCell. My recommendation: stick with LSTM for the time being, and consider LSTMCell only if you really need more control over the recurrence.
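If you want to convince yourself of the "more or less the same" part, a small check along these lines should work: copy the single-layer LSTM's weights into an LSTMCell and loop over the time steps by hand. The sizes and the random input are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_dim, hidden_dim, seq_len = 4, 6, 3  # hypothetical sizes
lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1, bidirectional=False, dropout=0.0)
cell = nn.LSTMCell(input_dim, hidden_dim)

# Copy the layer-0 parameters of the LSTM into the cell so both compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, 1, input_dim)  # (seq_len, batch, input_dim)

# nn.LSTM: one call for the whole sequence.
lstm_out, _ = lstm(x)

# nn.LSTMCell: explicit loop, one time step per call.
h = torch.zeros(1, hidden_dim)
c = torch.zeros(1, hidden_dim)
cell_outputs = []
for t in range(seq_len):
    h, c = cell(x[t], (h, c))
    cell_outputs.append(h)
cell_out = torch.stack(cell_outputs)  # (seq_len, 1, hidden_dim)

matches = torch.allclose(lstm_out, cell_out, atol=1e-5)
```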

Re #2: Yes, padding is the common technique. Don’t forget to add a fixed EOS (end-of-sequence) token at the end of each caption, so the decoder learns when to stop generating. At the beginning, you can just use padding and see how it works. For more advanced methods, you can check out PackedSequence, or even sort your dataset in such a way that each batch has captions of the same length.
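A small sketch of the padding part, assuming made-up PAD/EOS indices and random scores in place of real decoder outputs. It uses the ignore_index argument of CrossEntropyLoss so that padded positions don't count toward the mean, and compares that against an explicit mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)

PAD, EOS = 0, 1   # hypothetical special-token indices
vocab_size = 10   # hypothetical

# Two made-up captions of different lengths; append EOS to each, then pad.
captions = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
captions = [torch.cat([c, torch.tensor([EOS])]) for c in captions]
targets = pad_sequence(captions, batch_first=True, padding_value=PAD)  # (2, 4)

# Stand-in for the decoder's scores: (batch, max_len, vocab_size).
logits = torch.randn(targets.size(0), targets.size(1), vocab_size)

# ignore_index drops padded positions from both the sum and the mean.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))

# Same thing by hand: per-token losses, masked, then averaged.
per_token = F.cross_entropy(
    logits.reshape(-1, vocab_size), targets.reshape(-1), reduction="none"
)
mask = targets.reshape(-1) != PAD
manual_loss = per_token[mask].mean()
```

This relies on PAD having its own index that never appears as a real target word.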

In short: use LSTM and padding to get a basic network training. Then you can see whether you want to, and can, improve on it.
