I’m writing a custom backward function following http://pytorch.org/docs/master/notes/extending.html.
I plan to add a parameter that is updated in the backward function and then used in the next iteration of the forward function. Something like this:

class LinearFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        use_new_parameter(ctx.new_parameter)  # use new parameter
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        ctx.new_parameter += 1  # update parameter
        return grad_input, grad_weight, grad_bias

However, it seems that ctx only stores variables for a single iteration. Any idea how I can update parameters in backward? Thank you.

Thank you, Richard. That is exactly what I plan to do now: define parameters outside Function. Is there any nicer way to implement it?
BTW, I’m a little confused with “ctx”. Where is it from?

I think defining the parameters outside the Function is the way to go, and I don’t know of a nicer approach. Function isn’t meant to store state across multiple forward/backward passes.

The ctx object in Function's forward and backward is created when forward is called and destroyed once backward has run (unless retain_graph is involved, I think, but I’m not too familiar with how that works).
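To make the "state outside the Function" pattern concrete, here is a minimal sketch. The names `LinearWithCounter`, `state`, and `"counter"` are all hypothetical; the idea is simply to pass a long-lived Python object into `apply` as a non-differentiable argument, so the ctx can reference it while the object itself survives across iterations:

```python
import torch
from torch.autograd import Function

class LinearWithCounter(Function):
    @staticmethod
    def forward(ctx, input, weight, state):
        # `state` is a plain dict owned by the caller, so it outlives ctx
        # and can be read again in the next forward pass.
        ctx.save_for_backward(input, weight)
        ctx.state = state
        return input.mm(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        ctx.state["counter"] += 1  # persists: the dict outlives this ctx
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        # Return None for the non-tensor `state` argument.
        return grad_input, grad_weight, None

state = {"counter": 0}
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
for _ in range(3):
    out = LinearWithCounter.apply(x, w, state)
    out.sum().backward()
print(state["counter"])  # prints 3: incremented once per backward pass
```

A fresh ctx is created on every call to `apply`, which is why storing the counter on ctx alone doesn’t carry over; keeping it in an external dict (or as an attribute of an nn.Module wrapping this Function) sidesteps that.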