Correct way of storing states inside one forward pass

What is the recommended/most efficient way of keeping and updating a buffer of states that I can access, all inside the forward pass of my network?

I wrote a custom RNN layer for multi-dimensional sequences, but let’s take a 2D image as an example. The layer takes two hidden states (one for each dimension) and the input (a pixel value) and returns a new hidden state. In the network I have to build a list of these hidden states to use them in later computations.
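
To make this concrete, a stripped-down cell of this kind could look like the following (the names and the exact update rule here are just placeholders, not my actual layer):

    import torch
    import torch.nn as nn

    class MDRNNCell(nn.Module):
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.input_proj = nn.Linear(input_size, hidden_size)
            self.h1_proj = nn.Linear(hidden_size, hidden_size, bias=False)
            self.h2_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        def forward(self, x_t, h_dim1, h_dim2):
            # x_t:    (batch, input_size)   current pixel value
            # h_dim1: (batch, hidden_size)  hidden state coming from dimension 1
            # h_dim2: (batch, hidden_size)  hidden state coming from dimension 2
            return torch.tanh(
                self.input_proj(x_t) + self.h1_proj(h_dim1) + self.h2_proj(h_dim2)
            )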

Currently I’m initializing a tensor of the size of my images that holds all hidden states with

context_buffer = torch.empty(dim1, dim2, batch_size, self.hidden_size)

and then update it with

context_buffer.data[t1, t2] = h_t

I use .data because without it the backward pass would fail, but I don’t feel like this is the ideal solution. The network does actually learn, but the backward pass takes ages.
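
To give a clearer picture, the whole pattern condensed into one function looks roughly like this (the neighbour indexing and the zero initial state h0 are simplified for illustration):

    import torch

    def forward_with_buffer(cell, x, hidden_size):
        # x: (dim1, dim2, batch_size, input_size)
        dim1, dim2, batch_size, _ = x.shape
        h0 = x.new_zeros(batch_size, hidden_size)
        context_buffer = torch.empty(dim1, dim2, batch_size, hidden_size)
        for t1 in range(dim1):
            for t2 in range(dim2):
                # previous states along both dimensions, h0 at the borders
                h_dim1 = context_buffer[t1 - 1, t2] if t1 > 0 else h0
                h_dim2 = context_buffer[t1, t2 - 1] if t2 > 0 else h0
                h_t = cell(x[t1, t2], h_dim1, h_dim2)
                context_buffer.data[t1, t2] = h_t  # the .data assignment in question
        return context_buffer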

What is the best way to handle this? I looked at PyTorch buffers, but I feel like they are meant for storing data shared between several forward passes. Do I just have to change the way I assign the states? Concatenate them to an initial state instead of filling them in?

Any help appreciated!

Update:

I now tried appending each new hidden state h_t along one column of the image with

states = torch.cat((states, h_t.unsqueeze(0)))

and then appending this tensor to the other columns with

context_buffer = torch.cat((context_buffer, states.unsqueeze(0)))

This also works, but is even slower.
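
Put together, the second variant looks roughly like this (again simplified); since torch.cat allocates a new tensor and copies the accumulated states on every call, I suspect the growing concatenation is what makes it even slower:

    import torch

    def forward_with_cat(cell, x, hidden_size):
        # x: (dim1, dim2, batch_size, input_size)
        dim1, dim2, batch_size, _ = x.shape
        h0 = x.new_zeros(batch_size, hidden_size)
        context_buffer = x.new_zeros(0, dim2, batch_size, hidden_size)
        for t1 in range(dim1):
            states = x.new_zeros(0, batch_size, hidden_size)
            for t2 in range(dim2):
                h_dim1 = context_buffer[t1 - 1, t2] if t1 > 0 else h0
                h_dim2 = states[t2 - 1] if t2 > 0 else h0
                h_t = cell(x[t1, t2], h_dim1, h_dim2)
                # every cat copies everything accumulated so far
                states = torch.cat((states, h_t.unsqueeze(0)))
            context_buffer = torch.cat((context_buffer, states.unsqueeze(0)))
        return context_buffer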

For reference, when passing a batch of 32 images (84x64) the average times are:

  1. First method (buffer assignment): 0.5s forward, 11s backward
  2. Second method (torch.cat): 0.67s forward, 13s backward

I really hope someone knows a better solution to this…

Hi,

You don’t want to use .data, as it breaks the computational graph and so your gradients won’t be correct.
If you just want to keep a bunch of states, the simplest approach is to keep them in a Python list.
There is no speed advantage in reusing a single buffer in most cases, as the custom allocator we use makes tensor creation almost free.

For the very slow backward, keep in mind that if you do 10 different forwards through your network, it will need to do 10 different backwards, which can take a while as the computational graph is quite big.
If you don’t want gradients to flow back, you can use .detach() to prevent that.
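
For example, a rough sketch of the list version (just an illustration with made-up names, the neighbour indexing will depend on your layer): append each state to plain Python lists and stack them once at the end.

    import torch

    def forward_with_list(cell, x, hidden_size):
        # x: (dim1, dim2, batch_size, input_size)
        dim1, dim2, batch_size, _ = x.shape
        h0 = x.new_zeros(batch_size, hidden_size)
        rows = []                                    # states of finished rows
        for t1 in range(dim1):
            row_states = []
            for t2 in range(dim2):
                h_dim1 = rows[t1 - 1][t2] if t1 > 0 else h0    # from the row above
                h_dim2 = row_states[t2 - 1] if t2 > 0 else h0  # from the left neighbour
                row_states.append(cell(x[t1, t2], h_dim1, h_dim2))
            rows.append(row_states)
        # stack into (dim1, dim2, batch_size, hidden_size), still attached to the graph
        return torch.stack([torch.stack(r) for r in rows])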

Hey,

thanks for the answer, it definitely helps, but also confused me a little.

If I store the computed states in a simple python list, I need to detach them first. Does this not break the backward pass when I use them again for the next time step(s)?

This is probably not a PyTorch question; I guess I’ll go wrap my head around the formulas again…

You don’t have to detach() them before putting them in the list 🙂
Be careful, though, to detach everything that should be detached, and make sure the list empties out so you don’t leak a lot of tensors!
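
For example (a plain 1D RNN and made-up names here, just to show the pattern): rebuild the list inside every iteration so the old tensors and their graph can be freed, and detach() whatever state you carry over into the next iteration.

    import torch

    def train(cell, criterion, optimizer, loader, hidden_size):
        carried_h = None
        for seq, target in loader:        # seq: (seq_len, batch, input_size)
            states = []                   # fresh list every iteration, so the
                                          # previous iteration's graph can be freed
            h = carried_h if carried_h is not None \
                else seq.new_zeros(seq.size(1), hidden_size)
            for x_t in seq.unbind(0):
                h = cell(x_t, h)
                states.append(h)
            loss = criterion(torch.stack(states), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            carried_h = h.detach()        # keep the values, drop the old graph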

Reviving this thread because I have now run into a problem that I think is related to your last sentence.

Same project: I’m now using a simple Python list to store my states and then build a tensor from it with torch.stack for the subsequent computations.

Now I need loads of memory during training. Current runs with multiple layers and larger states fail on my university’s cluster because they run out of memory (251GB is available on one node).

What exactly do you mean by

make sure the list empties out so you don’t leak a lot of tensors

And how do I achieve this?

From this thread I found out that I apparently store the whole computational graph in my list with each iteration.

How am I supposed to handle this if I need the states for the gradient computation? Can I still detach them at some point?