How to transition to functions not being allowed to have member variables

Hi,

save_for_backward is only meant for Tensors that are inputs or outputs of the function, and it stores them in a slightly more involved way.
If you just want some state to pass around, you can give a dictionary to your forward:

from torch.autograd import Function

class YourFn(Function):
    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # This assumes that my_state is NOT a Tensor
        # If it is, you have to use ctx.save_for_backward()
        # or you will see a memory leak
        ctx.my_state = my_state
        # compute the output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        my_state = ctx.my_state
        # compute grad1, grad2
        # one gradient per forward argument; the state dict gets None
        return grad1, grad2, None

# arg1 and arg2 are your input Tensors
fn_state = {}
output = YourFn.apply(arg1, arg2, fn_state)
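To make the pattern concrete, here is a minimal runnable sketch. The elementwise multiply, the name MulWithState, and the "forward_calls" / "backward_calls" keys are made up for illustration and are not part of the original question. The Tensors go through ctx.save_for_backward, while the plain dict is attached to ctx directly.

```python
import torch
from torch.autograd import Function


class MulWithState(Function):
    """Hypothetical example: elementwise x * w with a dict carried as state."""

    @staticmethod
    def forward(ctx, x, w, state):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(x, w)
        # The state is a plain Python dict (not a Tensor), so it can be
        # stored as an attribute on ctx.
        ctx.state = state
        state["forward_calls"] = state.get("forward_calls", 0) + 1
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        ctx.state["backward_calls"] = ctx.state.get("backward_calls", 0) + 1
        # One return value per forward argument; the dict gets None.
        return grad_output * w, grad_output * x, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
state = {}

out = MulWithState.apply(x, w, state)
out.sum().backward()

print(x.grad, w.grad, state)
```

Since the dict is passed by reference, anything forward or backward writes into it is visible to the caller afterwards, which is what replaces the old member variables.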