To start off, the term “memory” is overloaded: I use it both for GPU RAM and for the memory of a neural network. I am trying to find an efficient way to implement a neural network with memory whose GPU RAM usage does not grow as O(N^2) in the number of steps. Essentially, I have a function `f` that processes my memory tensor and produces a new output, and that output is appended to the memory tensor. This repeats for `s` steps. The approach that runs but has the O(N^2) memory issue is `memory_issue` in the script below. The growth comes from `torch.cat`, which copies all of the previous memory into a new tensor at every step. Because each of those copies is passed to `f`, autograd keeps all of them alive until the backward pass, so the retained memory grows quadratically.
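A back-of-the-envelope sketch of that growth, counting only the storage of the memory tensor itself. It assumes `f` saves its whole input for backward (as `nn.RNN` must, to compute its weight gradients) and ignores `f`'s own internal activations; `retained_bytes` is a hypothetical helper, not part of PyTorch.

```python
def retained_bytes(batch_size, memory_dim, steps, bytes_per_elem=4):
    """Rough bytes of memory-tensor storage kept alive for backward.

    Hypothetical helper. Assumes f saves its whole input (as nn.RNN does)
    and ignores f's internal activations.
    """
    slot_bytes = batch_size * memory_dim * bytes_per_elem
    # torch.cat: the tensor fed to f at step i is a separate copy with i + 1
    # slots, and every copy is retained: 1 + 2 + ... + steps slots in total.
    cat_slots = steps * (steps + 1) // 2
    # Preallocation: every slice fed to f is a view into one (steps + 1)-slot buffer.
    prealloc_slots = steps + 1
    return cat_slots * slot_bytes, prealloc_slots * slot_bytes


# float32 with the batch and memory sizes from the script below, but 1000 steps
print(retained_bytes(4, 64, 1000))  # (512512000, 1025024)
```

With these sizes, concatenation retains roughly 512 MB of memory-tensor copies versus about 1 MB for a single preallocated buffer.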
One workaround I tried is to preallocate the memory, but this causes an autograd error because the memory tensor's version no longer matches the version autograd expects. However, at each step the network only reads from memory slots that have already been written to, so the activations in those slots are never modified in place, even though other slots of the same tensor are. This method is `version_issue` in the script below.
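A minimal sketch of the same failure without the RNN, using only core PyTorch; `buf`, `x`, and `version_counter_demo` are hypothetical names. A view such as `buf[0]` shares its version counter with `buf`, so writing to row 1 invalidates a saved copy of row 0 even though row 0's values never change.

```python
import torch


def version_counter_demo():
    x = torch.ones(3, requires_grad=True)
    # Preallocated buffer; row 0 is written, then read by an op that saves it
    buf = torch.zeros(2, 3)
    buf[0] = x * 2
    y = (buf[0] ** 2).sum()  # pow saves the view buf[0] for its backward
    version_before = buf._version
    buf[1] = x * 3           # touches only row 1, but bumps the shared counter
    version_bumped = buf._version > version_before
    try:
        y.backward()
        backward_failed = False
    except RuntimeError:     # "...modified by an inplace operation..."
        backward_failed = True
    return version_bumped, backward_failed


print(version_counter_demo())  # (True, True)
```

Autograd tracks versions per storage rather than per element, which is why the check cannot tell that the slots `f` read were left untouched.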
```python
import torch
import torch.nn as nn

batch_size = 4
memory_dim = 64
steps = 2

f = nn.RNN(memory_dim, memory_dim, batch_first=True)
input_ = torch.rand((batch_size, memory_dim), dtype=torch.float32)


def memory_issue(f, input_):
    # The second dimension is the memory slot dimension
    memory = input_.unsqueeze(1)
    for i in range(steps):
        print('Step {}'.format(i))
        # nn.RNN returns (output, h_n); output has shape (batch, slots, memory_dim)
        output = f(memory[:, :i+1])
        # MEMORY ISSUE: torch.cat allocates a new, larger memory tensor at each step
        memory = torch.cat([memory, output[0][:, -1].unsqueeze(1)], dim=1)
    loss = (1 - memory[:, -1]).mean()
    return loss


def version_issue(f, input_):
    # Preallocate the entire memory
    memory = torch.empty((batch_size, steps+1, memory_dim), dtype=torch.float32)
    # The input is placed in the first memory slot
    memory[:, 0] = input_
    for i in range(steps):
        print('Step {}'.format(i))
        output = f(memory[:, :i+1])
        # VERSION ISSUE: this in-place write bumps memory._version, even though
        # the values in memory[:, :i+1] (saved by f for backward) do not change
        memory[:, i+1] = output[0][:, -1]
    loss = (1 - memory[:, -1]).mean()
    return loss


print('Starting memory issue')
loss = memory_issue(f, input_)
loss.backward()
print('Ending memory issue')

print('Starting version issue')
loss = version_issue(f, input_)
loss.backward()  # raises RuntimeError: a tensor needed for backward was modified in place
print('Ending version issue')
```
A computationally inefficient solution would be to perform the computation in two passes. In the first pass, I calculate the values in memory with gradients turned off. In the second pass, I initialize the memory to the final values from the first pass and then use indexing to process only the correct memory slots at each step.
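A minimal sketch of that two-pass idea, assuming the same `f` and loss as the script above; `two_pass` is a hypothetical name, and `steps` is passed as an argument instead of read from a global. Pass 1 runs under `torch.no_grad()`, so the in-place writes record no graph; pass 2 only reads slices of the finished buffer, so no saved tensor is ever modified.

```python
import torch
import torch.nn as nn


def two_pass(f, input_, steps):
    batch_size, memory_dim = input_.shape
    # Pass 1: fill a preallocated buffer with autograd disabled. No graph is
    # recorded, so the in-place writes cannot trigger a version-counter error.
    with torch.no_grad():
        memory = torch.empty(batch_size, steps + 1, memory_dim)
        memory[:, 0] = input_
        for i in range(steps):
            out, _ = f(memory[:, :i + 1])
            memory[:, i + 1] = out[:, -1]
    # Pass 2: recompute with gradients on. memory is never written again, so
    # the slices f saves for backward keep the version they were saved at.
    outputs = []
    for i in range(steps):
        out, _ = f(memory[:, :i + 1])
        outputs.append(out[:, -1])
    loss = (1 - outputs[-1]).mean()
    return loss


torch.manual_seed(0)
f = nn.RNN(64, 64, batch_first=True)
loss = two_pass(f, torch.rand(4, 64), steps=2)
loss.backward()  # succeeds; f's parameters now have .grad
```

One caveat with this sketch: in pass 2 the memory slots are plain first-pass values, so gradients reach `f`'s parameters only through each step's direct call and do not flow back through earlier memory slots the way they do in `memory_issue`. The two approaches therefore generally produce different gradients.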