I’m writing a customized backward function, following the Extending PyTorch page of the PyTorch master documentation.
I plan to add a parameter that gets updated in the backward function and is then used in the next iteration's forward pass. Something like this:
```python
import torch
from torch.autograd import Function

class LinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        use_new_parameter(ctx.new_parameter)  # use new parameter
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors  # saved_variables is deprecated
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        ctx.new_parameter += 1  # update parameter
        return grad_input, grad_weight, grad_bias
```
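To make the failure concrete, here is a minimal driver (the shapes are hypothetical, and `use_new_parameter` is just a stand-in for my real code):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)

# Every .apply() builds a brand-new ctx, so new_parameter set in one
# backward never reaches the next forward. In fact, the very first forward
# already raises AttributeError, because new_parameter was never set on
# its fresh ctx.
out = LinearFunction.apply(x, w)
```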
However, as the sketch above suggests, ctx only lives for a single forward/backward pair, so anything I store on it is gone by the next iteration. Any idea how I can update a parameter in backward so that it persists across iterations? Thank you.
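For reference, one direction I’ve been experimenting with is to keep the parameter outside the Function entirely, pass it in as an extra non-differentiable argument, and let backward mutate it in place. Roughly like this (the names and shapes here are mine, and I’m not sure this is the intended pattern):

```python
import torch
from torch.autograd import Function

class StatefulLinear(Function):
    @staticmethod
    def forward(ctx, input, weight, state):
        ctx.save_for_backward(input, weight)
        ctx.state = state                      # reference to a caller-owned tensor
        return input.mm(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        ctx.state += 1                         # in-place update outlives this ctx
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        return grad_input, grad_weight, None   # no gradient for the state argument

state = torch.zeros(1)                         # lives across iterations
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
for _ in range(3):
    StatefulLinear.apply(x, w, state).sum().backward()
print(state)                                   # tensor([3.])
```

Is mutating caller-owned state from backward like this actually safe, or is there a better-supported way to do it?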