Hi,
I’m using a custom RNN implementation and am finding that training uses roughly 2x more CUDA memory than with LSTMs. To investigate, I recorded CUDA allocations and found that memory use grows unreasonably due to repeated calls to torch.mm inside the RNN loop; I wasn’t expecting the same line in my model to keep increasing memory use. As a fix, I tried passing the out keyword to avoid new CUDA allocations, but got: RuntimeError: mm(): functions with out=… arguments don’t support automatic differentiation, but one of the arguments requires grad.
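
Here is a minimal snippet that reproduces that error for me (the shapes and names here are made up for illustration, not my actual weights):

import torch

a = torch.randn(8, 16, device="cuda")
w = torch.randn(16, 32, device="cuda", requires_grad=True)  # stands in for a weight parameter
buf = torch.empty(8, 32, device="cuda")                     # preallocated output buffer

torch.mm(a, w, out=buf)
# RuntimeError: mm(): functions with out=... arguments don't support
# automatic differentiation, but one of the arguments requires grad.
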
Is there any way to avoid new CUDA allocations when using torch.mm inside a loop? This is a really big problem for me, since it severely limits my model size and/or truncated-backprop window. Thanks in advance, and fingers crossed! (If I have to learn CUDA to get memory usage as light as the built-in RNN layers, that would really suck.)
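
For context, this is roughly how I recorded the growth, boiled down to a toy loop (sizes are illustrative only, not my real model):

import torch

device = torch.device("cuda")
weight = torch.randn(512, 512, device=device, requires_grad=True)  # toy stand-in for a weight
hidden = torch.zeros(64, 512, device=device)                       # toy hidden state

for i in range(10):
    hidden = torch.tanh(torch.mm(hidden, weight))   # same torch.mm line every iteration
    # the allocated byte count keeps climbing each iteration here
    print(i, torch.cuda.memory_allocated(device))
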
My model’s forward code is below:
@torch.compile
def forward(self, input_seq: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
    nseq, nbatch, nx = input_seq.shape
    epss = torch.randn((nseq, nbatch, self.hidden_size), device=input_seq.device)
    epss = epss.unbind(0)
    inputs = input_seq.unbind(0)
    outputs = torch.jit.annotate(List[Tensor], [])
    # x_results = torch.randn((nbatch, self.hidden_size))
    # z_results = torch.randn((nbatch, self.hidden_size))
    for i in range(len(input_seq)):
        x = inputs[i]
        eps = epss[i]
        predicted_distribution = torch.mm(hidden, self.weight_encoder)
        mean_, z = predicted_distribution.chunk(2, 1)
        z = mean_ + eps * torch.exp(0.5 * z)
        # lines below allocate new CUDA memory for each i
        x_results = torch.mm(x, self.weight_ih)
        z_results = torch.mm(z, self.weight_zh)
        # lines below don't support autograd
        # torch.mm(x, self.weight_ih, out=x_results)
        # torch.mm(z, self.weight_zh, out=z_results)
        r, z, n = x_results.chunk(3, 1)
        z_r, z_z, z_n = z_results.chunk(3, 1)
        z = torch.sigmoid(z + z_z)
        n = torch.tanh(n + torch.sigmoid(r + z_r) * z_n)
        hidden = n + torch.mul(z, (hidden - n))
        outputs += [hidden]
    gc.collect()
    return torch.stack(outputs), hidden