Hi,
You only want to save inputs or outputs with save_for_backward, and they are saved in a slightly complex way.
If you just want to have some state to pass around, you can give a dictionary to your forward:
from torch.autograd import Function

class YourFn(Function):
    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # This assumes that my_state is NOT a Tensor.
        # If it is, you have to use ctx.save_for_backward(),
        # or you will see a memory leak.
        ctx.my_state = my_state
        # compute the output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        my_state = ctx.my_state
        # compute grad1, grad2
        # One return value per forward input; None for the non-Tensor my_state.
        return grad1, grad2, None

fn_state = {}
output = YourFn.apply(arg1, arg2, fn_state)
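For a version you can actually run, here is a minimal sketch that fills in the template with a hypothetical computation (an elementwise product `arg1 * arg2`; the class name `MulWithState` and the counter keys are made up for illustration). The input tensors go through `save_for_backward`, while the plain dict is stored directly on `ctx`. Because the dict is mutable and shared, whatever `forward` and `backward` write into it is visible to the caller afterwards.

```python
import torch
from torch.autograd import Function


class MulWithState(Function):
    """Hypothetical example: elementwise product that records non-Tensor state."""

    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(arg1, arg2)
        # my_state is a plain dict (not a Tensor), so storing it on ctx is fine.
        my_state["forward_calls"] = my_state.get("forward_calls", 0) + 1
        ctx.my_state = my_state
        return arg1 * arg2

    @staticmethod
    def backward(ctx, grad_output):
        arg1, arg2 = ctx.saved_tensors
        my_state = ctx.my_state
        my_state["backward_calls"] = my_state.get("backward_calls", 0) + 1
        # d(arg1 * arg2)/d(arg1) = arg2 and d(arg1 * arg2)/d(arg2) = arg1
        grad1 = grad_output * arg2
        grad2 = grad_output * arg1
        # One return value per forward input; the dict gets no gradient.
        return grad1, grad2, None


fn_state = {}
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = MulWithState.apply(a, b, fn_state)
out.sum().backward()
print(fn_state)  # {'forward_calls': 1, 'backward_calls': 1}
```

The counters are only there to show that the same dict object is seen by `forward`, by `backward`, and by the caller; in practice you would put whatever non-Tensor configuration or bookkeeping you need in it.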