What is the recommended (most efficient) way of keeping and updating a buffer of states that I can access, all inside the forward pass of my network?
I wrote a custom RNN layer for multi-dimensional sequences, but let's take a 2D image as an example. The layer takes two hidden states (one per dimension) and the input (a pixel value) and returns a new hidden state. Inside the network I have to build up a collection of these hidden states to use in later computations.
Currently I'm initializing a tensor the size of my image that holds all the hidden states with
context_buffer = torch.empty(dim1, dim2, batch_size, self.hidden_size)
and then update it with
context_buffer.data[t1, t2] = h_t
The .data is there because without it the backward pass fails, but this doesn't feel like the ideal solution. The network does actually learn, but the backward pass takes ages.
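For concreteness, here is a stripped-down sketch of what I'm doing now. The cell is a placeholder nn.Linear (my real cell is more involved), and the names are just for illustration:

```python
import torch
import torch.nn as nn

class RNN2D(nn.Module):
    """Simplified stand-in for my layer."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # placeholder cell: combines the pixel with the two hidden states
        self.cell = nn.Linear(input_size + 2 * hidden_size, hidden_size)

    def forward(self, x):  # x: (dim1, dim2, batch_size, input_size)
        dim1, dim2, batch_size, _ = x.shape
        context_buffer = torch.empty(dim1, dim2, batch_size, self.hidden_size)
        zero = x.new_zeros(batch_size, self.hidden_size)
        for t1 in range(dim1):
            for t2 in range(dim2):
                # previous states along each dimension, zeros at the borders
                h_up = context_buffer[t1 - 1, t2] if t1 > 0 else zero
                h_left = context_buffer[t1, t2 - 1] if t2 > 0 else zero
                h_t = torch.tanh(self.cell(
                    torch.cat([x[t1, t2], h_up, h_left], dim=-1)))
                context_buffer.data[t1, t2] = h_t  # the assignment in question
        return context_buffer
```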
What is the best way to handle this? I looked at PyTorch buffers (register_buffer), but those seem to be meant for data shared between several forward passes. Do I just have to change the way I assign the states, i.e. concatenate them to an initial state instead of filling in a preallocated tensor, as in the sketch below?
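This is roughly the alternative I have in mind (same placeholder cell as above): collect the states in plain Python lists and stack them once at the end, so no tensor is ever written in place.

```python
    def forward(self, x):  # drop-in alternative for the forward above
        dim1, dim2, batch_size, _ = x.shape
        zero = x.new_zeros(batch_size, self.hidden_size)
        rows = []
        for t1 in range(dim1):
            row = []
            for t2 in range(dim2):
                h_up = rows[t1 - 1][t2] if t1 > 0 else zero
                h_left = row[t2 - 1] if t2 > 0 else zero
                h_t = torch.tanh(self.cell(
                    torch.cat([x[t1, t2], h_up, h_left], dim=-1)))
                row.append(h_t)
            rows.append(row)
        # (dim1, dim2, batch_size, hidden_size), no in-place writes
        return torch.stack([torch.stack(row) for row in rows])
```

Is that the intended pattern, or is there something cheaper?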
Any help appreciated!