I am trying to define a new class using autograd.Function (say f(x, g)).
This function takes a tensor x and an nn.Module (say g with parameters theta).
I compute the output by applying some variation of g to x and returning the result.
I am a little fuzzy on how to define the backward for this function.
How should I backpropagate through the function g, and what needs to be returned by the backward function?
```python
class F(autograd.Function):
    @staticmethod
    def forward(ctx, x, g):
        with torch.enable_grad():
            x = x.clone().requires_grad_(True)
            z = g(x)
        ctx.save_for_backward(x, z)
        ctx._function = g
        return z

    @staticmethod
    def backward(ctx, output_grad):
        x, z = ctx.saved_tensors
        g = ctx._function
        with torch.enable_grad():
            z.backward(??)
        return x.grad * output_grad, ??
```
I know that this setup may seem a bit absurd (e.g., why don't I just apply g directly to x and then backpropagate), but it is useful for my use case.
Thank you very much for your response.
So, as far as I understand, in your solution the function g(.) is assumed to be fixed, since you returned None as its gradient. How can I also get gradients for the trainable parameters of g(.)? What needs to be returned instead of None, given that g(.) is a module rather than a Tensor?
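For what it's worth, here is one pattern that seems to work for this (a sketch, not necessarily the canonical answer): since autograd.Function only tracks Tensor arguments, pass g's parameters to apply as extra tensor inputs, rebuild an inner graph in forward under torch.enable_grad(), and in backward use torch.autograd.grad over that inner graph, returning one gradient per parameter (and None for the module slot). The use of a detached clone of x and of torch.autograd.grad here are this sketch's choices, not something from the original post:

```python
import torch
from torch import nn, autograd

class F(autograd.Function):
    @staticmethod
    def forward(ctx, x, g, *params):
        # Run g under an enabled-grad context on a detached copy of x,
        # so we get an inner graph we can differentiate in backward.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            z = g(x_)
        ctx.save_for_backward(x_, z, *params)
        # Return a detached tensor; autograd attaches F's own backward node.
        return z.detach()

    @staticmethod
    def backward(ctx, grad_output):
        x_, z, *params = ctx.saved_tensors
        # Differentiate the inner graph w.r.t. x and all parameters at once.
        grads = torch.autograd.grad(
            z, (x_,) + tuple(params),
            grad_outputs=grad_output,
            allow_unused=True,  # a parameter may be unused by this forward
        )
        # One return value per forward argument: x, the module slot (None
        # because g is not a Tensor), then one gradient per parameter.
        return (grads[0], None) + tuple(grads[1:])

# Usage: the parameters must be passed explicitly so autograd sees them.
g = nn.Linear(3, 2)
x = torch.randn(4, 3, requires_grad=True)
z = F.apply(x, g, *g.parameters())
z.sum().backward()  # fills x.grad, g.weight.grad, g.bias.grad
```

The key point is the call `F.apply(x, g, *g.parameters())`: if you only pass the module, autograd never sees theta as inputs and cannot route gradients to it, which is why backward has to return None there. An alternative is to keep returning None and manually accumulate the parameter gradients into each `p.grad` inside backward, but passing the parameters explicitly lets autograd do the accumulation for you.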