RNN memory management: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

This is a very basic question, sorry. For memory management reasons I’d like to write simple RNN code like:

  s = torch.empty_like(x)     
  s[0] = x[0]
  for t in range(1, x.shape[0]):
    s[t] = decay * s[t-1] + x[t]

Now I understand why I can’t do that: s is one tensor, and autograd needs to see a sequence of separate tensors so it can backprop from the end to the start. So I’ve always done what autograd needs (and what is recommended in RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32, 768]], which is output 0 of TanhBackward, is at version 2; expected version 0 instead. Hint: enable anomaly detectio ):

  s = [x[0]]
  for t in range(1, x.shape[0]):
    s.append(decay * s[t-1] + x[t])
  s = torch.stack(s)

I’m now considering learning the Triton language in order to speed up the above code. However, that seems a bit pointless if I’m going to read all the blocks and then write them back out as one object at the end; the whole point of a Triton kernel would be to take in a single tensor and output a single tensor with minimal memory bandwidth.

I can of course use torch.clone() to get the one-tensor-in, one-tensor-out behaviour that I want, but this runs massively slower, and so doesn’t seem like the right direction:

s = torch.empty_like(x)
s[0] = x[0]
for t in range(1, x.shape[0]):
  s[t] = decay * s[t-1].clone() + x[t]

All I’m really after is some magic syntax that tells autograd that my output is one tensor but please treat it like a list of independent tensors that just happen to be adjacent in memory and we aren’t really going to have any graph clashes, honest.
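The closest thing I know of to that “magic syntax” is a custom torch.autograd.Function: inside forward(), autograd is disabled, so the in-place writes are invisible to it, and you supply the backward recursion by hand. A sketch under the assumption that decay is a plain Python float (the name DecayScan is mine, not a PyTorch API):

```python
import torch

class DecayScan(torch.autograd.Function):
    """Sketch: one tensor in, one tensor out, with a handwritten
    backward so autograd never sees the in-place writes."""

    @staticmethod
    def forward(ctx, x, decay):
        # Autograd is off inside forward(), so writing into s is fine.
        s = torch.empty_like(x)
        s[0] = x[0]
        for t in range(1, x.shape[0]):
            s[t] = decay * s[t - 1] + x[t]
        ctx.decay = decay
        return s

    @staticmethod
    def backward(ctx, grad_s):
        # Since ds[u]/dx[t] = decay**(u - t) for u >= t, accumulate
        # from the end: g[t] = grad_s[t] + decay * g[t + 1].
        grad_x = grad_s.clone()
        for t in range(grad_x.shape[0] - 2, -1, -1):
            grad_x[t] = grad_x[t] + ctx.decay * grad_x[t + 1]
        return grad_x, None  # no gradient for the float decay
```

Usage is `s = DecayScan.apply(x, decay)`. Autograd then treats the whole scan as one opaque op, which is also exactly the forward/backward pair shape a Triton kernel would need.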

So yeah, the next step would definitely be Triton or some functional approach for PyTorch. But before you try that, I would personally try @torch.compile first. It supports loop unrolling and is a bit better at reasoning about multiple tensors. If that’s still too slow, the next steps are a lot more complex.


Thanks a lot.

I’ve been a huge fan of torch.compile() for a long time. torch.compile() doesn’t change the Python syntax or the behaviour of autograd, so I didn’t mention it, to keep the post minimal. The model is a language model, and as such its runtime is dominated by the final softmax, but just for completeness, and because I said that the torch.clone() version was slower, here are the timings:

torch.stack() solution: eager 1423ms, compiled 275ms

torch.clone() solution: eager 3029ms, compiled 584ms

And your other suggestion was nn.functional. I’d sort of missed the point that nn.functional was split off because it was functional, I’m very glad to have realised this (a bit of a ‘duh’ moment for me).

AHA! I’ve just realised that the code I need is this:

s = torch.empty_like(x)
acc = x[0]          # acc is a fresh tensor each step; s is only written, never read
s[0] = acc
for t in range(1, x.shape[0]):
  acc = decay * acc + x[t]
  s[t] = acc
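As a sanity check (a quick sketch; the reference is the list-and-stack recurrence from earlier in the thread), the single-tensor accumulator produces the same values and gradients:

```python
import torch

torch.manual_seed(0)
decay = 0.9
x = torch.randn(8, 4, requires_grad=True)

# Accumulator version: acc carries the recurrence, s is only written.
s = torch.empty_like(x)
acc = x[0]
s[0] = acc
for t in range(1, x.shape[0]):
    acc = decay * acc + x[t]
    s[t] = acc
s.sum().backward()

# Reference: the list-and-stack version.
x_ref = x.detach().clone().requires_grad_(True)
out = [x_ref[0]]
for t in range(1, x_ref.shape[0]):
    out.append(decay * out[t - 1] + x_ref[t])
torch.stack(out).sum().backward()

print(torch.allclose(s, torch.stack(out)),
      torch.allclose(x.grad, x_ref.grad))
```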

acc solution: eager 2006ms, compiled 331ms

Conceptually, acc stays in SRAM and DRAM holds only one copy. The backward pass has the same shape (but with the t range reversed, of course, and I don’t need to keep the individual gradients; I can accumulate across all time steps). This is the conceptually clean code I was after. Thanks for helping me.