How to efficiently store a tensor that represents a neural network's memory without GPU RAM issues or in-place operation issues

To start off, the term “memory” is overloaded here, since I am referring both to GPU RAM and to a neural network with memory. I am trying to figure out an efficient way to implement a neural network with memory whose GPU RAM usage does not grow as O(N^2) in the number of steps. Essentially, I have a function f that processes my memory tensor and produces a new output, and that output is appended to the memory tensor. This proceeds for s steps. The approach that runs but has the O(N^2) memory issue is memory_issue in the script below. The memory growth arises from torch.cat, which copies all of the previous memory into a new memory tensor at every step.

One workaround I tried is to preallocate the memory, but this causes an autograd error because the memory tensor's version counter no longer matches the version recorded when the tensor was saved for backward. However, at each step the network only reads from memory slots that have already been written, so the activations in those slots are never changed in-place, even though later slots are written in-place. This method is version_issue in the script below.

import torch.nn as nn
import torch

batch_size = 4
memory_dim = 64
steps = 2

f = nn.RNN(memory_dim, memory_dim, batch_first=True)
input_ = torch.rand((batch_size, memory_dim), dtype=torch.float32)

def memory_issue(f, input_):
    # The second dimension is the memory slot dimension
    memory = input_.unsqueeze(1)
    for i in range(steps):
        print('Step {}'.format(i))
        output = f(memory[:, :i+1])
        # MEMORY ISSUE: The memory is expanded at each step
        memory = torch.cat([memory, output[0][:, -1].unsqueeze(1)], dim=1)
    loss = (1 - memory[:, -1]).mean()
    return loss

def version_issue(f, input_):
    # Pre allocate the entire memory
    memory = torch.empty((batch_size, steps+1, memory_dim), dtype=torch.float32)
    # The input is placed in the first memory slot
    memory[:, 0] = input_
    for i in range(steps):
        print('Step {}'.format(i))
        output = f(memory[:, :i+1])
        # VERSION ISSUE: memory._version is now updated, even though the values at memory[:, :i+1] do not change
        memory[:, i+1] = output[0][:, -1]

    loss = (1 - memory[:, -1]).mean()
    return loss

print('Starting memory issue')
loss = memory_issue(f, input_)
loss.backward()
print('Ending memory issue')

print('Starting version issue')
loss = version_issue(f, input_)
# This backward raises the in-place modification RuntimeError described above
loss.backward()
print('Ending version issue')

A computationally inefficient solution would be to perform the computation in two passes. In the first pass, I calculate the values in memory with gradients turned off. In the second pass, I initialize the memory to the final values from the first pass and then use indexing so that only the correct memory slots are processed at each step.
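For concreteness, here is a minimal sketch of that two-pass idea (my own illustration, not code from the scripts above). It assumes the precomputed buffer can be treated as a constant in the second pass, so gradients reach f through the reads at each step but do not flow back through the stored slots from one step to an earlier one.

def two_pass(f, input_):
    # Pass 1: fill the memory buffer with gradients disabled, as in version_issue
    with torch.no_grad():
        memory = torch.empty((batch_size, steps + 1, memory_dim), dtype=torch.float32)
        memory[:, 0] = input_
        for i in range(steps):
            output = f(memory[:, :i + 1])
            memory[:, i + 1] = output[0][:, -1]

    # Pass 2: the buffer already holds the final values, so each step only reads
    # the valid slots and nothing is written in-place inside the autograd graph
    last = None
    for i in range(steps):
        output = f(memory[:, :i + 1])
        last = output[0][:, -1]

    loss = (1 - last).mean()
    return loss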

I think I may have solved this issue using a custom torch.autograd.Function. Does MemorySet look correct? In particular, I'm not sure what to return for the first and second return values in backward. This function lets me increase my batch size, but I'm seeing a slowdown in the forward/backward pass, so the original torch.cat version with O(N^2) memory ends up with higher throughput (samples/sec) despite its smaller batch size.

class MemorySet(torch.autograd.Function):
    @staticmethod
    def forward(ctx, memory, x, index):
        # Write x into the chosen slot through .data so the version counter is not bumped
        memory.data[:, index].copy_(x)
        ctx.index = index
        return memory

    @staticmethod
    def backward(ctx, grad_out):
        index = ctx.index
        # Pass the incoming gradient through to memory unchanged, route the gradient
        # at the written slot to x, and return None for the integer index
        return grad_out, grad_out[:, index], None

memory_set = MemorySet.apply

# Pre allocate the entire memory
memory = torch.zeros((batch_size, steps + 1, memory_dim), dtype=torch.float32)
# The input is placed in the first memory slot
memory = memory_set(memory, input_, 0)
for i in range(steps):
    print('Step {}'.format(i))
    output = f(memory[:, :i + 1])
    memory = memory_set(memory, output[0][:, -1], i + 1)

loss = (1 - memory[:, -1]).mean()
loss.backward()
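For reference, a quick sanity check along these lines (a sketch, not part of the scripts above): run the torch.cat version and the MemorySet version on the same input and compare the gradients accumulated in f's parameters; if MemorySet routes gradients correctly, the two should agree.

# Reference gradients from the torch.cat version
f.zero_grad()
reference_loss = memory_issue(f, input_)
reference_loss.backward()
reference_grads = [p.grad.clone() for p in f.parameters()]

# Gradients from the MemorySet version on the same input
f.zero_grad()
memory = torch.zeros((batch_size, steps + 1, memory_dim), dtype=torch.float32)
memory = memory_set(memory, input_, 0)
for i in range(steps):
    output = f(memory[:, :i + 1])
    memory = memory_set(memory, output[0][:, -1], i + 1)
loss = (1 - memory[:, -1]).mean()
loss.backward()

# Compare parameter gradients from the two versions
for reference_grad, p in zip(reference_grads, f.parameters()):
    print(torch.allclose(reference_grad, p.grad, atol=1e-6))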