What is the recommended (and most efficient) way of keeping and updating a buffer of states that I can access, all inside the forward pass of my network?
I wrote a custom RNN layer for multi-dimensional sequences, but let’s take a 2D image as an example. The layer takes two hidden states (one per dimension) and the input (a pixel value) and returns a new hidden state. In the network I have to build up a list of these hidden states to use them in later computations.
Currently I’m initializing a tensor of the size of my images that holds all hidden states, and I fill it in during the forward pass.
I use .data because without it the backward pass would fail, but I don’t feel like this is the ideal solution. The network does actually learn, but the backward pass takes ages.
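A minimal sketch of the pattern described above; the cell, names, and sizes are illustrative stand-ins, not the original code:

```python
import torch

class RNN2DCell(torch.nn.Module):
    """Hypothetical stand-in for the custom 2D-RNN cell described above."""
    def __init__(self, hidden_size):
        super().__init__()
        self.lin = torch.nn.Linear(2 * hidden_size + 1, hidden_size)

    def forward(self, h_left, h_up, x):
        # combine the two incoming hidden states and the pixel value
        return torch.tanh(self.lin(torch.cat([h_left, h_up, x], dim=-1)))

H, W, hidden = 4, 4, 8
cell = RNN2DCell(hidden)
img = torch.rand(H, W)

# one buffer holding a hidden state per pixel
states = torch.zeros(H, W, hidden)
for i in range(H):
    for j in range(W):
        h_left = states[i, j - 1] if j > 0 else torch.zeros(hidden)
        h_up = states[i - 1, j] if i > 0 else torch.zeros(hidden)
        # writing through .data avoids the autograd error,
        # but silently cuts these states out of the graph
        states.data[i, j] = cell(h_left, h_up, img[i, j].view(1))
```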
What is the best way to handle this? I looked at PyTorch buffers, but I feel like they are meant for storing data shared between several forward passes. Do I just have to change the way I assign the states? Concatenate them to an initial state instead of filling them in?
You don’t want to use .data as it breaks the computational graph and so your gradients won’t be correct.
If you just want to keep a bunch of states, the simplest is to keep them in a Python list.
There is no speed advantage in reusing a single buffer in most cases, as the custom allocator we use makes tensor creation almost free.
For the very slow backward, keep in mind that if you do 10 different forwards through your network, it will need to do 10 different backwards, which can take a while as the computational graph is quite big.
If you don’t want gradients to flow back, you can use .detach() to prevent that.
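A minimal sketch of the list-based approach (the linear layer here is a stand-in for the custom RNN cell, not the poster’s actual layer):

```python
import torch

cell = torch.nn.Linear(2, 2)        # stand-in for the custom RNN cell
x = torch.rand(5, 2, requires_grad=True)

states = []                         # plain Python list of per-step states
h = torch.zeros(2)
for t in range(5):
    h = torch.tanh(cell(x[t] + h))
    states.append(h)                # no .data, no .detach: graph stays intact

out = torch.stack(states)           # (5, 2) tensor for later computations
out.sum().backward()                # gradients flow back through every step
```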
Thanks for the answer, it definitely helps, but it also confused me a little.
If I store the computed states in a simple python list, I need to detach them first. Does this not break the backward pass when I use them again for the next time step(s)?
This is probably not a PyTorch question, I guess I’ll go wrap my head around the formulas again…
You don’t have to .detach() them before putting them in the list.
Be careful though to detach everything that should be detached, and make sure the list empties out so you don’t leak a lot of tensors!
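For example, when training over chunks of a sequence, a common pattern is to re-create the list each iteration and detach the carried state, so old graphs can be freed (a sketch with a stand-in cell, not the original model):

```python
import torch

cell = torch.nn.Linear(1, 1)        # stand-in for the RNN cell
h = torch.zeros(1)

for chunk in range(3):              # e.g. one training iteration per chunk
    states = []                     # fresh list: the old graph can be freed
    for t in range(4):
        h = torch.tanh(cell(h))
        states.append(h)
    loss = torch.stack(states).sum()
    loss.backward()
    h = h.detach()                  # cut the graph before the next chunk
```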
Reviving this thread, because I now encounter a problem that I think is related to your last sentence.
Same project: I’m now using a simple Python list to store my states and then generating a tensor from it with torch.stack for the following computations.
Now during training I need loads of memory. Current runs with multiple layers and larger states fail on my university’s cluster because they run out of memory (251 GB are available on one node).
What exactly do you mean by
make sure the list empties out so you don’t leak a lot of tensors