RNN saving Hidden state


(Guy Rosenthal) #1

Hi,

I have a long sequence that requires relatively long memory.
I break each sequence into consecutive parts that are fed to the network in their original order. I want to keep the hidden state at the end of each batch, so that it becomes the initial hidden state for the next batch.

The following code describes my training loop.
Is this the correct way to do it?

import numpy as np
import torch

for epoch in range(N_epochs):
    model.hidden = model.init_hidden(bs=batch_size)
    model.last_hidden = None
    start_idx = np.arange(0, samples.shape[1], sub_step)
    for s_idx in start_idx:
        sub_sample = X_batch[:, s_idx:s_idx + sub_step]
        sub_target = Y_batch[:, s_idx:s_idx + sub_step]
        sample_v = torch.autograd.Variable(sub_sample)
        target_v = torch.autograd.Variable(sub_target)

        if model.last_hidden is not None:
            model.hidden = [torch.autograd.Variable(h) for h in model.last_hidden]

        net_output = model(sample_v)

        loss = loss_function(net_output, target_v)
        model.zero_grad()
        loss.backward()
        optimizer.step()

        _ = model(sample_v)
        # save last hidden state
        model.last_hidden = [h.data for h in model.hidden]

(jpeg729) #2

It looks mostly functional to me.

A couple of remarks…

First remark: Your data X_batch seems to be of shape (batches, timesteps, features), so I assume your RNN layers were created with batch_first=True. Otherwise the RNN will silently treat the batch dimension as time (and vice versa) without raising an error, and you will get funky results.
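To make the batch_first point concrete, here is a small sketch using nn.LSTM as a stand-in for whatever RNN layer the model uses (the sizes are made up). It also shows that leaving batch_first off does not raise an error; the layer just reads the batch dimension as time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, timesteps, features, hidden_size = 4, 10, 3, 8

# Hypothetical data laid out as (batch, time, features), like X_batch in the question
x = torch.randn(batch_size, timesteps, features)

# With batch_first=True, dim 0 is read as batch and dim 1 as time
rnn = nn.LSTM(input_size=features, hidden_size=hidden_size, batch_first=True)
out, (h_n, c_n) = rnn(x)
print(out.shape)  # torch.Size([4, 10, 8]) -> (batch, time, hidden)
print(h_n.shape)  # torch.Size([1, 4, 8]) -> (num_layers, batch, hidden); batch_first does not affect h_n

# Without batch_first, the same tensor is read as 4 timesteps of a batch of 10, with no error
rnn_tf = nn.LSTM(input_size=features, hidden_size=hidden_size)
out2, (h2, c2) = rnn_tf(x)
print(h2.shape)   # torch.Size([1, 10, 8]) -> the "batch" is now the time axis
```

Note that the returned hidden state keeps the (num_layers, batch, hidden) layout even with batch_first=True, which matters if you save and feed it back yourself.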

With X_batch of shape (batches, timesteps, features), sub_sample will not be contiguous, so either PyTorch will throw an error complaining that the data is non-contiguous, or it will quietly copy the data into a contiguous block somewhere along the way. Basically, if your code runs, then it runs inefficiently.

It would be better if X_batch and Y_batch were of shape (timesteps, batches, features), because then …

sub_sample = X_batch[s_idx:s_idx + sub_step] # slicing over the time dimension

will create slices of contiguous data and PyTorch won’t have to copy the data to make the Variables contiguous.
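A quick way to check this is Tensor.is_contiguous(). The sketch below uses made-up sizes and compares slicing over time in both layouts; the one-off transpose(0, 1).contiguous() is just one way to convert existing data to the time-first layout.

```python
import torch

batches, timesteps, features, sub_step = 4, 100, 3, 20

# Layout from the question: (batches, timesteps, features)
x_btf = torch.randn(batches, timesteps, features)
chunk_btf = x_btf[:, 0:sub_step]             # slice over dim 1 (time)
print(chunk_btf.is_contiguous())             # False

# Suggested layout: (timesteps, batches, features)
x_tbf = x_btf.transpose(0, 1).contiguous()   # one-off copy into the time-first layout
chunk_tbf = x_tbf[0:sub_step]                # slice over dim 0 (time)
print(chunk_tbf.is_contiguous())             # True

# Both chunks hold the same numbers, just in different layouts
print(torch.equal(chunk_btf, chunk_tbf.transpose(0, 1)))  # True
```

The time-first slice is contiguous because each time step's (batch, features) block sits in one piece of memory, so cutting along time never skips over data.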

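Second remark: The line _ = model(sample_v) runs the model a second time on the same input, but model.hidden has already been updated by the first pass. The second pass therefore starts from the state at the end of the chunk, so the state you save is what the model would have after seeing the chunk twice in a row. This is unnecessary and will mess with the model's memory of the past.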
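Here is a small sketch of that effect, with a plain nn.GRU standing in for the model (the sizes and the random chunk are hypothetical). Rerunning the chunk from the already-updated state gives the same result as feeding the chunk twice in a row from the original state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.GRU(input_size=3, hidden_size=5)    # default layout: (time, batch, features)
chunk = torch.randn(20, 2, 3)                # hypothetical chunk: 20 steps, batch of 2
h0 = torch.zeros(1, 2, 5)

with torch.no_grad():
    # First pass: this is the state the next chunk should start from
    _, h_once = rnn(chunk, h0)
    # Extra pass on the same chunk, starting from the already-updated state
    _, h_twice = rnn(chunk, h_once)
    # Feeding the chunk twice in a row from the original state
    _, h_doubled = rnn(torch.cat([chunk, chunk]), h0)

print(torch.allclose(h_twice, h_doubled, atol=1e-6))  # True
```

So the state saved after the extra pass is not the state "after this chunk"; it belongs to a sequence in which the chunk appears twice.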

I would suggest dispensing with model.last_hidden and perhaps using .detach_() instead of repackaging the hidden state:

for s_idx in start_idx:
    # prepare sample_v and target_v
    for h in model.hidden:
        h.detach_() # tells PyTorch not to backpropagate further back into the past
    # model.hidden = [Variable(h.data) for h in model.hidden] # would work too
    net_output = model(sample_v)
    loss = loss_function(net_output, target_v)
    model.zero_grad()
    loss.backward()
    optimizer.step()
    # don't rerun the model on sample_v
    # no need to save model.last_hidden
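For reference, here is a self-contained sketch of the same idea in current PyTorch (no Variable wrappers), on synthetic data. The names, sizes, toy target and the separate LSTM + Linear head are all assumptions for illustration. It uses the out-of-place h.detach() and reassigns the state, which has the same effect as detach_() here (keep the values, cut the graph) without modifying tensors in place. Without the detach, the second loss.backward() would try to go back through the previous chunk's graph, which has already been freed, and typically raises an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes; the long sequence is laid out as (time, batch, features)
timesteps, batch_size, n_in, n_hidden, sub_step = 200, 4, 3, 16, 25
X = torch.randn(timesteps, batch_size, n_in)
Y = X.cumsum(dim=0)[:, :, :1] * 0.1         # toy target that depends on the whole past

lstm = nn.LSTM(n_in, n_hidden)
head = nn.Linear(n_hidden, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)
loss_function = nn.MSELoss()

for epoch in range(3):
    # Fresh (h, c) state at the start of each pass over the sequence
    hidden = (torch.zeros(1, batch_size, n_hidden),
              torch.zeros(1, batch_size, n_hidden))
    for s_idx in range(0, timesteps, sub_step):
        sample = X[s_idx:s_idx + sub_step]  # contiguous slice over time
        target = Y[s_idx:s_idx + sub_step]
        # Keep the state's values but stop backpropagation into earlier chunks
        hidden = tuple(h.detach() for h in hidden)
        output, hidden = lstm(sample, hidden)  # new state is carried to the next chunk
        loss = loss_function(head(output), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())
```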

(Guy Rosenthal) #3

@jpeg729 there is something basic I'm missing here regarding the second forward pass that you say is not needed. I verified it on toy examples, and you are correct that it should be removed.
The reason I think the second forward pass is more accurate is that the next batch should get a hidden state that reflects the weight update.
That is, after the weights are updated, rerunning the batch with the updated weights will result in a different hidden state. Otherwise, we are feeding the next batch a hidden state that was generated by the forward pass before the weight update.
Can anyone explain what I am missing here?


(jpeg729) #4

You have a point. However, for efficiency most people don't do what you suggest. Waiting until the next epoch for the hidden states to be computed with the new weights is probably not going to slow down training much, whereas repeating the forward pass adds a second forward pass for every chunk, which can add substantially to the time needed for each epoch.

If you do want to repeat the forward pass, then you would have to save the hidden state before running the first forward pass in order to use it for the second. Something like this.

for s_idx in start_idx:
    # prepare sample_v and target_v
    for h in model.hidden:
        h.detach_() # tells PyTorch not to backpropagate further back into the past
    # model.hidden = [Variable(h.data) for h in model.hidden] # would work too

    old_hidden = [h.clone() for h in model.hidden]

    # do a forward pass and update the weights
    net_output = model(sample_v)
    loss = loss_function(net_output, target_v)
    model.zero_grad()
    loss.backward()
    optimizer.step()

    # redo a forward pass with the new weights, starting from the saved state
    model.hidden = old_hidden
    model(sample_v)  # only the updated model.hidden is kept
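If you do go this way, it is probably worth running the second pass under torch.no_grad(), since it is only there to produce the next state and nothing backpropagates through it. Below is a self-contained sketch in current PyTorch on random data; the LSTM + Linear model, the sizes and the SGD optimizer are all assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and random data, laid out as (time, batch, features)
timesteps, batch_size, n_in, n_hidden, sub_step = 100, 2, 3, 8, 20
X = torch.randn(timesteps, batch_size, n_in)
Y = torch.randn(timesteps, batch_size, 1)

lstm = nn.LSTM(n_in, n_hidden)
head = nn.Linear(n_hidden, 1)
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.1)
loss_function = nn.MSELoss()

hidden = (torch.zeros(1, batch_size, n_hidden), torch.zeros(1, batch_size, n_hidden))
for s_idx in range(0, timesteps, sub_step):
    sample = X[s_idx:s_idx + sub_step]
    target = Y[s_idx:s_idx + sub_step]
    old_hidden = tuple(h.detach() for h in hidden)  # state before this chunk, no graph attached

    # Forward pass and weight update, starting from the saved state
    output, _ = lstm(sample, old_hidden)
    loss = loss_function(head(output), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Redo the forward pass with the updated weights, only to get the next chunk's state
    with torch.no_grad():
        _, hidden = lstm(sample, old_hidden)
```

Because the rerun builds no graph, it costs only the extra forward computation, and the state it returns already has no history to detach.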