Updating parameters in a custom backward function

I’m writing a customized backward function following the Extending PyTorch — PyTorch master documentation page.
I plan to add a parameter that is updated in the backward function and then used in the next iteration’s forward pass. Something like this:

class LinearFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        use_new_parameter(ctx.new_parameter)  # use the new parameter (pseudocode)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        ctx.new_parameter += 1  # update the parameter
        return grad_input, grad_weight, grad_bias

However, it seems that ctx only stores variables for a single iteration. Any idea how I can update parameters in backward? Thank you.
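For reference, this is roughly how I drive the Function each iteration; the shapes, the loss, and use_new_parameter above are just placeholders for illustration:

import torch

# placeholder shapes and loss, only to show how LinearFunction would be called
weight = torch.randn(3, 5, requires_grad=True)
bias = torch.randn(3, requires_grad=True)

for step in range(5):
    input = torch.randn(4, 5)
    output = LinearFunction.apply(input, weight, bias)  # forward would read ctx.new_parameter
    loss = output.sum()
    loss.backward()                                     # backward would do ctx.new_parameter += 1
    # ...but the ctx from this call is gone by the next iteration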

As you noticed, ctx doesn’t support what you’re looking for. Could you please provide a little more context on what you’re doing?

Hi Richard,

I’m implementing a Torch nn.Criterion (https://github.com/facebookresearch/wav2letter/blob/master/wav2letter/batchautosegcriterionc.lua) in PyTorch.

It requires updating parameters in the backward pass.

Is it possible to update the parameters outside of the backward pass?

Thank you, Richard. That is exactly what I plan to do now: define the parameters outside the Function. Is there a nicer way to implement it?
BTW, I’m a little confused about “ctx”. Where does it come from?

I think defining the parameters outside the Function is the way to do it, and I don’t know of a nicer way. I don’t think Function is meant to store state across multiple forward/backward passes.
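For example, here’s a rough sketch of one way to do that (the names LinearWithState, StatefulLinear, and extra_state are just made up for illustration): keep the extra state as a buffer on the nn.Module that wraps the Function, pass it into forward, and update it in place in backward. Because the module owns that tensor and autograd doesn’t track it, the same object is still there on the next iteration:

import torch
import torch.nn as nn
from torch.autograd import Function

class LinearWithState(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, extra_state):
        ctx.save_for_backward(input, weight, bias)
        ctx.extra_state = extra_state           # plain tensor owned by the module, not tracked by autograd
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        ctx.extra_state += 1                     # in-place update; survives because the module holds the tensor
        return grad_input, grad_weight, grad_bias, None  # no gradient for extra_state

class StatefulLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.register_buffer("extra_state", torch.zeros(1))  # persistent across iterations

    def forward(self, input):
        return LinearWithState.apply(input, self.weight, self.bias, self.extra_state)

StatefulLinear can then be used like any other module, and because extra_state is registered as a buffer it moves with .to(device) and shows up in state_dict(). If you instead want the optimizer to update that extra state, make it an nn.Parameter and return a proper gradient for it from backward rather than mutating it in place.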

The ctx object in a Function’s forward and backward is created when forward is called and destroyed once backward has run (unless retain_graph is used, I think, but I’m not too familiar with how that works).
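A toy example that shows this (the names are made up): stash a counter on ctx and notice that it never carries over between calls:

import torch
from torch.autograd import Function

class Probe(Function):
    @staticmethod
    def forward(ctx, x):
        # every .apply() call gets a brand-new ctx, so there is never a previous value here
        ctx.counter = getattr(ctx, "counter", 0) + 1
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        print("counter =", ctx.counter)   # prints 1 on every iteration
        return grad_output * 2

x = torch.randn(3, requires_grad=True)
for _ in range(3):
    Probe.apply(x).sum().backward()       # counter = 1, three times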

Thank you, Richard. Your comments are very useful.