How to transition to functions not being allowed to have member variables

I used to have a torch.autograd.Function with self variables that could be modified both within the forward and backward calls and from outside the function, and the values would carry over correctly. Now I have instead created a new class that just stores these values. I can pass them all into the forward function, save them with ctx.save_for_backward, and use them in backward. However, it looks like if I change them outside the function (especially between forward and backward), the values seen in backward don't change, and if I change them within backward, the values outside backward don't change. Is there a recommended method for something like this now that torch.autograd.Function is meant to be static?
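
Roughly the kind of thing I mean, heavily simplified (StatefulFn and counter are made-up names, the real state is more involved): a plain value stored on ctx is captured at forward time, so rebinding the outer variable afterwards is not seen in backward, and nothing assigned in backward is seen outside.

import torch
from torch.autograd import Function

class StatefulFn(Function):
    @staticmethod
    def forward(ctx, x, counter):
        ctx.counter = counter  # keeps a reference to the value passed in right now
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # ctx.counter is still whatever was passed at forward time; rebinding the
        # outer variable below is never seen here, and assigning a new value to
        # ctx.counter here is never seen outside.
        return grad_output * 2, None

counter = 0
x = torch.randn(3, requires_grad=True)
y = StatefulFn.apply(x, counter)
counter = 5  # backward() will still see 0
y.sum().backward()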

Hi,

You only want to save inputs or outputs with save_for_backward, and they are saved in a slightly complex way.
If you just want some state to pass around, you can give a dictionary to your forward:

from torch.autograd import Function

class YourFn(Function):
    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # This assumes that my_state is NOT a Tensor.
        # If it is, you have to use ctx.save_for_backward()
        # or you will see a memory leak.
        ctx.my_state = my_state
        # compute the output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        my_state = ctx.my_state
        # compute grad1, grad2
        return grad1, grad2, None

fn_state = {}
output = YourFn.apply(arg1, arg2, fn_state)
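
To make this concrete, here is a small runnable version of the sketch above (the multiply operation and the "backward_calls" key are just for illustration). Because the dict is a mutable object shared by reference, changes made outside between forward and backward, or inside backward, are visible everywhere:

import torch
from torch.autograd import Function

class MulFn(Function):
    @staticmethod
    def forward(ctx, a, b, my_state):
        ctx.my_state = my_state      # non-Tensor state: a plain ctx attribute is fine
        ctx.save_for_backward(a, b)  # Tensor inputs go through save_for_backward
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        ctx.my_state["backward_calls"] = ctx.my_state.get("backward_calls", 0) + 1
        return grad_output * b, grad_output * a, None

fn_state = {"note": "set before forward"}
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
out = MulFn.apply(a, b, fn_state)

fn_state["note"] = "changed between forward and backward"  # visible inside backward too
out.sum().backward()
print(fn_state)  # {'note': 'changed between forward and backward', 'backward_calls': 1}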

Looks like that works! Any additional input on how to make a dictionary work in conjunction with DataParallel to use multiple GPUs? I'm now getting "RuntimeError: The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0", which I'm guessing is because some values are being split across GPUs and others are not.

Asked in a new thread here: Resetting Dataparallel after initialization, and using it in a dictionary

While googling to try to solve my problem (memory leaks from a custom function), I came across this thread, which describes how I was doing it originally, and that was working, haha. Did this answer not take memory leaks into account, or has something changed since last year? Or, upon further research, was this answer assuming the state was made of normal variables and not other tensors?

Yes, from the question I guess I assumed that the extra arguments were not Tensors, and so that wouldn't be a problem.
Let me edit the answer to avoid any confusion in the future. 🙂
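
For example (just a sketch; the running_mean and momentum names are made up): if part of the state is a Tensor that backward needs, one option consistent with the advice above is to pass it as an explicit input and route it through ctx.save_for_backward, keeping only the non-Tensor items on ctx:

import torch
from torch.autograd import Function

class FnWithTensorState(Function):
    @staticmethod
    def forward(ctx, x, running_mean, state):
        # Tensor inputs needed in backward go through save_for_backward ...
        ctx.save_for_backward(x, running_mean)
        # ... while non-Tensor state (here a plain float) can live on ctx.
        ctx.momentum = state["momentum"]
        return x - running_mean

    @staticmethod
    def backward(ctx, grad_output):
        x, running_mean = ctx.saved_tensors
        # gradient w.r.t. x only; the other inputs get None
        return grad_output, None, None

state = {"momentum": 0.1}
x = torch.randn(4, requires_grad=True)
running_mean = torch.zeros(4)
out = FnWithTensorState.apply(x, running_mean, state)
out.sum().backward()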
