Batch Training RNNs


If I understand it correctly, when training RNNs with mini-batch SGD, the elements in one batch should not be sequential. Rather, each index position across the batches corresponds to one sequence. I can see that this makes sense when one has multiple sequences to train on.

Currently I’m working on a problem where I have only one ongoing time series, with no individual sequences. Is it common practice to artificially create sequences by splitting the training data? Or should I train with batch size 1 and treat my whole training set as one sequence?

It would be cool if someone could confirm my understanding of the matter and/or give some tips on best practices for this sort of situation.




Either approach would work. Training will run quicker with batches made up of subsequences.


I like to make sure that the subsequences in each batch are the continuations of the subsequences in the previous batch. That way I can use the final hidden state of one batch as the initial hidden state of the next batch.

If you do that then you must either

  • detach the hidden state between batches in order to cut off backpropagation. The fact that the hidden state is retained allows the model to learn something about the influence of previous history, but not as much as full BPTT would allow.
  • or use the retain_graph=True option when calling .backward(), but in this case each batch will take longer than the previous batch because it will be backpropagating all the way through time to the beginning of the first batch.
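A minimal sketch of the first option (detaching between batches, i.e. truncated BPTT), assuming a recent PyTorch where plain tensors replace Variable. The sizes, the linear head, the SGD optimizer and the random stand-in data are all hypothetical choices for illustration; the point is where the hidden state is carried over, where it is detached, and where it is reset at the start of each epoch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
seq_len, batch_size, n_features, hidden_size = 5, 4, 3, 8
n_batches, n_epochs = 6, 2

lstm = nn.LSTM(n_features, hidden_size)
head = nn.Linear(hidden_size, n_features)  # hypothetical output layer
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.1)

# Random stand-in data; in practice batch k+1 should continue batch k column by column.
inputs = [torch.randn(seq_len, batch_size, n_features) for _ in range(n_batches)]
targets = [torch.randn(seq_len, batch_size, n_features) for _ in range(n_batches)]

losses = []
for epoch in range(n_epochs):
    hidden = None  # reset at each epoch: nn.LSTM treats a missing state as zeros
    for x, y in zip(inputs, targets):
        if hidden is not None:
            # Keep the values of (h, c) but cut them off from the previous
            # batch's graph, so backward stops at the start of this batch.
            hidden = tuple(h.detach() for h in hidden)
        output, hidden = lstm(x, hidden)
        loss = F.mse_loss(head(output), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
```

Removing the detach line should make the second loss.backward() raise a RuntimeError, because it would try to backpropagate into the previous batch's already-freed graph.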

@jpeg729 I am trying to train a model the way you are suggesting and I have a couple questions about it.

1- If I call loss.backward() after every batch, wouldn’t this reset the graph and therefore detach the hidden state, so that it won’t be used in the next batch?

2- At the beginning of each batch I reset the LSTM hidden state and cell state to 0. Is this a valid approach if I want the model to remember only what happens within one batch?

3- If I want to retain the last hidden state from the previous batch and avoid backpropagating all the way back to the beginning, couldn’t I just call loss.backward() at the end of each batch and leave the hidden state untouched?

I would love some clarification regarding this. Thanks in advance!


1 & 3. The computation graph is created during the forward pass and used by the backward pass. Normally the backward pass deletes the graph. The hidden state is the result of the calculation of the previous batch, and it remembers this fact even after the computation graph has been deleted, hence the need to explicitly detach the hidden state. You can try not detaching it and simply calling backward(), but you will get an error saying that you are trying to backward through the graph a second time.

  2. Yes.

1 & 3. So the hidden state is not part of the graph? If it is, why is it not detached by the .backward() call? And if it’s not, why not?
I thought the cell and hidden states were the result of operations applied to the weights plus some non-linearities inside each LSTM cell; shouldn’t that make them part of the graph? I feel like I am massively confused about this.

The hidden state is part of the graph.
When you call backward after the first batch, the computation graph is freed, but the hidden state still remembers that it was computed; it just no longer knows how. That is why, if you don’t explicitly detach the hidden state before the second batch, the second call to backward will fail with an error.
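A small sketch that reproduces this behaviour, assuming a recent PyTorch version; the sizes and the helper run_two_batches are hypothetical. Without the detach, the second backward() has to walk back into the first batch’s graph, whose saved tensors were already freed, so it should raise a RuntimeError; with the detach it should succeed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=2, hidden_size=3)
x1 = torch.randn(4, 1, 2)  # batch 1: (seq_len, batch, input_size)
x2 = torch.randn(4, 1, 2)  # batch 2, continuing batch 1

def run_two_batches(detach):
    """Hypothetical helper: run two consecutive batches, carrying the hidden state."""
    out1, hidden = lstm(x1)
    out1.sum().backward()  # frees the graph built for batch 1
    if detach:
        hidden = tuple(h.detach() for h in hidden)
    out2, _ = lstm(x2, hidden)
    try:
        # Without detach, this reaches back into batch 1's freed graph.
        out2.sum().backward()
        return "ok"
    except RuntimeError:
        return "error"

print(run_two_batches(detach=False))
print(run_two_batches(detach=True))
```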


I create new Variables for the hidden state, setting them to 0 at the beginning of each batch. Since these are new Variable objects, they are not part of the graph, right? I’m guessing this is why I don’t get an error when calling .backward() on batch number 2. Is this a correct approach if I’m treating each batch individually, as described in question 2? Or will this affect training?

Yep. If you create new variables, then they aren’t part of any graph yet, so you don’t need to detach them. The only downside will be that the model will have little to no understanding of dependencies in the data that are longer than the sequence length used for training.
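For comparison, a sketch of the fresh-state approach, written for a recent PyTorch in which Variable has been merged into Tensor (since 0.4), so plain zero tensors play the role of the new Variables. The sizes are hypothetical. Each batch builds its own graph from scratch, so backward() can be called every batch with no detach; the trade-off is the one just described, since nothing flows from one batch to the next.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size = 4, 2, 3, 5  # hypothetical sizes
lstm = nn.LSTM(input_size, hidden_size)

for step in range(3):
    x = torch.randn(seq_len, batch_size, input_size)
    # Brand-new zero tensors are leaves with no history (grad_fn is None),
    # so there is nothing to detach. Shape: (num_layers, batch, hidden_size).
    h0 = torch.zeros(1, batch_size, hidden_size)
    c0 = torch.zeros(1, batch_size, hidden_size)
    assert h0.grad_fn is None and c0.grad_fn is None
    output, (h_n, c_n) = lstm(x, (h0, c0))
    output.sum().backward()  # works every time: no graph is shared between batches
```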


Maybe that is why I’m not getting good results. I will try keeping the hidden state and detaching it between batches so the LSTM has a notion of the previous sequences. Maybe this will solve my problem. You are a lifesaver! Thanks a lot!!

If your data is stock market data, then that won’t help much. Stock market prices are full of noise and have little predictability.

Best of luck.

I am just trying to generate text, so I’m guessing your approach will help a lot.

Thanks for the discussion and the insightful answers @jpeg729. I will try this tonight when I’m home.


Hey! Sorry, I’ve got one more question.
I’m a little confused about the seq_len dimension in the inputs and outputs of the RNN-type modules. I was looking at nn.LSTM and nn.GRU. From their documentation I would assume that they accept the whole training corpus as input at once (which doesn’t seem very sensible), considering they have both a sequence and a batch dimension. Am I supposed to split my sequences into smaller subsequences, meaning one training iteration would get a batch of batches of training samples as input? I’m sorry if this is a newbie question; I have only ever trained vanilla nets and am a little confused about the RNN API at training time.

I think I’ll look at the cell implementations for now, since they don’t seem to have this sequence-dimension requirement.

The cell implementations take one timestep at a time. The LSTM, RNN and GRU all take inputs with several timesteps in one go. I find it helpful to be very clear about the distinction between the batch dimension, whose indices correspond to different input sequences, and the sequence dimension whose indices correspond to different timesteps of each sequence. The term “batch” can be a little ambiguous if you are not careful.
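A sketch of that distinction, with hypothetical sizes: a single-layer nn.LSTM is given a whole (seq_len, batch, input_size) tensor at once, while an nn.LSTMCell with the same weights is fed x[t], one timestep of shape (batch, input_size), in a Python loop. Assuming the weights are copied across as shown, the two outputs should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size = 6, 3, 4, 5  # hypothetical sizes

lstm = nn.LSTM(input_size, hidden_size)
cell = nn.LSTMCell(input_size, hidden_size)

# Give the cell the same weights as the single-layer LSTM.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

# dim 0 = timesteps (sequence dimension), dim 1 = independent sequences (batch dimension)
x = torch.randn(seq_len, batch_size, input_size)

# nn.LSTM consumes every timestep in one call.
out_lstm, (h_n, c_n) = lstm(x)  # out_lstm: (seq_len, batch_size, hidden_size)

# nn.LSTMCell consumes one timestep per call; we loop over dim 0 ourselves.
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)
steps = []
for t in range(seq_len):
    h, c = cell(x[t], (h, c))  # x[t]: (batch_size, input_size)
    steps.append(h)
out_cell = torch.stack(steps)  # (seq_len, batch_size, hidden_size)

print(torch.allclose(out_lstm, out_cell, atol=1e-5))
```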

I believe that the GPU implementation of LSTM is somewhat faster than the cell implementation, but I have never tried it myself.

I think most people use LSTMs either for short sequences (e.g. sentences to be classified), passing entire sequences to the LSTM in one go, or for long sequences, which they cut up into subsequences.

I have found that if I take a really long sequence and divide it up into subsequences which I stack into batches, then I can get through an epoch of training much faster. That said, I didn’t start out by doing that, I started with simpler methods and I went from there.


I think I just have some very fundamental knowledge gaps when it comes to what RNN inputs have to look like during training, specifically for the single-time-series case I’m looking at now. I feel like blog posts or articles about RNNs usually focus on the inputs at prediction time. When they talk about training, it’s mostly straight to BPTT, from which I guess one could infer the shape of the training data, but somehow I haven’t been able to do that. I’ll just keep trying; thanks for your help, and sorry that it didn’t amount to more yet lol

If you have any more questions don’t hesitate.


Alright, I kinda feel bad asking all these stupid questions, but well, you asked for it :)
I might have either kind of understood it now or gone fully mad.

Let’s say I have one “big” overall sequence of steps 1, 2, …, 8.
If I pick batch size = 4, that means I have created 4 subsequences and 2 batches
      b1  b2
s1     1   2
s2     3   4
s3     5   6
s4     7   8

bn = batch n, sn = subsequence n

I guess I can now decide how many batches I should use per training iteration; the number of batches I feed into my network per training iteration determines how long the training sequence used for one update is. This number of batches is the seq_len dimension in the RNN input. I guess the appropriate term for the sequence used in one training iteration would be “subsubsequence”? So if I, for example, chose to use only one batch per training iteration, I would use a subsequence of each of the 4 subsequences s1-s4.

Assuming what I wrote so far is actually correct and not completely insane, I would do updates using subsubsequences. I guess I should reset the hidden states after feeding in all the batches once, thereby finishing an epoch. But I guess I should not reset the hidden states after every single update, since I only updated based on a subsubsequence, and it would make sense for the following subsubsequence to use its subsequence’s previous hidden state. If this is sensible, I suppose I would have to save the hidden state at the end of every training iteration and use it to initialize the hidden state in the next training iteration.

Is this even close to correct? If not feel free to tell me where I’ve gone wrong. Thanks!

I will try to clarify what I think you are asking as best I can.

1 - Yes.

2 - No. The seq_len parameter is the size of your batch; the number of batches corresponds to the batch dimension in the LSTM input.

3 - Yes. You can decide how many batches back you want to backpropagate. If you have only 4 batches and you want to backpropagate through all of them to the beginning of your sequence, then I’m guessing you can call loss.backward(retain_graph=True) at the end of every batch. I don’t think this is very common, since with lots of data it will be very slow.

4 - When to reset the hidden state depends on what you want. You should reset your hidden state when you want your RNN to have no notion at all of what happened earlier. If you think it’s important to remember all of the data, then reset the hidden state after a full epoch; but if each batch is independent of the rest, then you can reset it after each .backward(). My questions and @jpeg729’s answers in this same thread could clarify or expand on this.

That said, I am a bit of a newbie so I hope someone with more knowledge could confirm or correct my answer. Especially points 3 & 4.

Lastly, never hesitate to ask these kinds of questions. This is the very purpose of these forums, and furthermore they help newbies like me understand things better.


That would be right if you were using the cell implementations. If you want to use the LSTM implementation, then it would be better to say that you have created “one batch of 4 subsequences of 2 timesteps each”, i.e. batch_size=4, seq_len=2.
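In tensor terms, a sketch of that reading of the 1..8 example (one feature per step; the variable names are hypothetical): the four subsequences become the batch dimension and their two steps become the sequence dimension.

```python
import torch

sequence = torch.arange(1, 9, dtype=torch.float32)  # steps 1..8, one feature each
batch_size = 4                                       # number of subsequences
seq_len = sequence.numel() // batch_size             # 2 timesteps per subsequence

# Rows of the (batch_size, seq_len) view are the subsequences s1..s4:
# [1, 2], [3, 4], [5, 6], [7, 8]
subsequences = sequence.view(batch_size, seq_len)

# Put time first and add a feature dimension: (seq_len, batch, input_size).
lstm_input = subsequences.t().unsqueeze(-1)
print(lstm_input.shape)  # torch.Size([2, 4, 1])
```

The first timestep seen by the LSTM is then [1, 3, 5, 7] (the first step of each subsequence) and the second is [2, 4, 6, 8].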

Batches and subsequences get mixed up very easily. I don’t think I can reply to @Diego’s answer without confusing things even more. So first some theory, then an example, then some code.

LSTM expects input of shape (seq_len, batch, input_size) by default (i.e. with batch_first=False).

  • seq_len = the length of each subsequence.
  • batch = how many subsequences it will process in parallel.
  • input_size = how many values (features) there are at each timestep of each sequence.
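A quick shape check, with hypothetical sizes that mirror the example below (subsequences of length 5, batches of 10): nn.LSTM returns one output vector per timestep per subsequence, plus the final (h_n, c_n). The sketch also shows the batch_first=True option, which swaps the first two dimensions of the input and output (but not of h_n and c_n).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 10, 3, 7  # hypothetical sizes

lstm = nn.LSTM(input_size, hidden_size)  # default batch_first=False
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 10, 7]): one hidden vector per timestep per subsequence
print(h_n.shape)     # torch.Size([1, 10, 7]): final hidden state of each subsequence

# Same weights, but batch_first=True: input and output are (batch, seq_len, features);
# h_n and c_n keep the (num_layers, batch, hidden_size) layout.
lstm_bf = nn.LSTM(input_size, hidden_size, batch_first=True)
lstm_bf.load_state_dict(lstm.state_dict())
output_bf, _ = lstm_bf(x.transpose(0, 1))
print(torch.allclose(output_bf.transpose(0, 1), output, atol=1e-6))
```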

Now for an example

  • Suppose I have one big sequence of steps 0, 1, 2, …, 99 with f features at each timestep. The big sequence is a tensor of size (100, f).

If I want to train on subsequences of length 5, then I cut my big sequence into subsequences of length 5; in this case there are 100 / 5 = 20 subsequences. Then, if I want to train with a batch size of 10, I take 10 of these subsequences at a time, which gives 2 batches.

In code, given a tensor sequence of shape (timesteps, features):

import torch

subsequences = sequence.split(desired_subsequence_length)
# each subsequence is of size (desired_subsequence_length, features)

# if the last subsequence is not of the same length, remove it.
if subsequences[-1].size(0) != desired_subsequence_length:
    subsequences = subsequences[:-1]
    # alternatively you could pad it, but I am not sure of the best method

# step through the subsequences desired_batch_size at a time
for b in range(0, len(subsequences), desired_batch_size):
    # stacking along a new dim 1 puts the batch dimension second
    training_batch = torch.stack(subsequences[b:b + desired_batch_size], dim=1)
    # training_batch is of size (desired_subsequence_length, desired_batch_size, features)
    # the last batch might contain fewer subsequences, but that doesn't matter here
    # if model is an nn.LSTM it returns (output, (h_n, c_n)); passing no initial
    # state makes it start from zeros, i.e. the hidden state is reset every batch
    output, hidden = model(training_batch)

With the above code, you must reset the hidden state before each batch because the batches don’t follow on from one another properly. If you want an example with batches that follow on properly, I can give you one, but I’ll wait until these concepts are in place.
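One possible layout for batches that follow on properly, offered as a sketch rather than as the example promised above: cut the series into batch_size contiguous streams, put time first, and slice consecutive windows of seq_len steps. Column i of each batch then continues column i of the previous batch, so the final hidden state of one batch can be detached and passed in as the initial state of the next, as in the first bullet near the top of the thread. The helper name continuous_batches is hypothetical, and leftover timesteps that don’t divide evenly are simply dropped.

```python
import torch

def continuous_batches(sequence, batch_size, seq_len):
    """Hypothetical helper: yield (seq_len, batch_size, features) batches such that
    column i of each batch continues column i of the previous batch."""
    steps_per_stream = sequence.size(0) // batch_size
    # Drop leftover timesteps that don't fit evenly into batch_size streams.
    trimmed = sequence[: steps_per_stream * batch_size]
    # Cut the series into batch_size contiguous streams, then put time first:
    # (steps_per_stream, batch_size, features)
    streams = trimmed.view(batch_size, steps_per_stream, -1).transpose(0, 1).contiguous()
    for start in range(0, steps_per_stream, seq_len):
        yield streams[start : start + seq_len]  # the last batch may be shorter

series = torch.arange(100, dtype=torch.float32).unsqueeze(1)  # (100, 1): steps 0..99
batches = list(continuous_batches(series, batch_size=4, seq_len=5))
print(len(batches), batches[0].shape)  # 5 torch.Size([5, 4, 1])
```

Here column 1 holds steps 25..49: the first batch ends that column at step 29 and the second batch picks it up at step 30.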