Custom RNN cell state saving?

Hi! I’m reading up on PyTorch, and would like to understand a bit better how custom RNN cells work.

  1. Does save_for_backward() work for operations like RNN cells where the same operation instance is forwarded multiple times before backward is called, or does anything special need to be done for this use case?

  2. Are there any limitations as to what operations can be used in an RNN cell, or can it be assumed that all built-in operations will save/restore state as needed (incl. any cuDNN state, if accelerated) such that they work in this multiple forward, then multiple backward case?

Thanks!

  1. Does save_for_backward() work for operations like RNN cells where the same operation instance is forwarded multiple times before backward is called, or does anything special need to be done for this use case?

Yes, it works fine.
This is because RNN cells are of type nn.Module, while save_for_backward is really implemented on autograd.Function. Each autograd Function instance inside a graph is used exactly once: if you unroll the cell over K time steps, you get K separate instantiations of the relevant Functions, each with its own saved tensors. nn.Module is an abstraction on top of autograd Functions that makes this seamless. A minimal sketch is below.
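For instance, here is a quick sketch (my own illustration with made-up names, not code from PyTorch itself) of a custom autograd Function applied once per time step. Each call to `apply` creates a fresh instance, so the tensor saved via save_for_backward belongs to that step alone:

```python
import torch

class MyTanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.tanh(x)
        ctx.save_for_backward(y)  # saved on this call's instance only
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        return grad_out * (1 - y * y)  # d tanh(x)/dx = 1 - tanh(x)^2

x = torch.randn(5, 3, requires_grad=True)
h = torch.zeros(5, 3)
for _ in range(4):            # K = 4 time steps
    h = MyTanh.apply(h + x)   # 4 independent Function instantiations
h.sum().backward()            # backward traverses all 4 nodes
print(x.grad.shape)
```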

  2. If you are referring to nn.RNN, there are limitations documented by the API itself, but you can create your own RNN cell however you wish; see the sketch below.
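As a hedged example, here is a hand-rolled Elman-style cell (MyRNNCell, input_size, hidden_size are illustrative names, not part of any PyTorch API). Any differentiable built-in ops are fine inside forward; autograd records each application separately when the cell is reused across steps:

```python
import torch
import torch.nn as nn

class MyRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, h):
        # reusing the same module across time steps is fine:
        # each call adds fresh nodes to the autograd graph
        return torch.tanh(self.i2h(x) + self.h2h(h))

cell = MyRNNCell(10, 20)
h = torch.zeros(3, 20)
for t in range(5):                     # unroll over 5 time steps
    h = cell(torch.randn(3, 10), h)
h.sum().backward()                     # backprop through time
```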

Thanks!

I noticed that the Dropout function has chosen to save its mask as self.noise rather than via self.save_for_backward(noise). Is there any difference between the two? I see plenty of other functions using save_for_backward() even when only a single tensor is being saved.

self.save_for_backward (or ctx.save_for_backward in the new function API) is needed for saving tensors that are inputs to the function; for intermediate values (including noise masks) you should use assignment to self.attr or ctx.attr instead. A sketch contrasting the two is below.
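Here is an illustrative dropout-like Function following that convention (MaskedScale is a made-up name, and this is not the actual Dropout source): the input tensor goes through ctx.save_for_backward, while the internally generated mask is stashed directly on ctx:

```python
import torch

class MaskedScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p):
        mask = (torch.rand_like(x) > p).to(x.dtype) / (1 - p)
        ctx.save_for_backward(x)  # x is an input: use save_for_backward
        ctx.mask = mask           # mask is intermediate: plain attribute
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors  # unused here, shown for illustration
        return grad_out * ctx.mask, None  # no gradient for p

x = torch.randn(4, requires_grad=True)
y = MaskedScale.apply(x, 0.5)
y.sum().backward()
```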