class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        # some computation
        ctx.save_for_backward(....)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # some computation
        return grad1, grad2
From what I understand, anything that happens to input1 inside of forward is not tracked, and the gradient applied is whatever is returned in backward. (For example, if loss = input1 * 2 but grad1 = grad_output * 10, the computation graph won’t care about the first statement and will treat the gradient as grad_output * 10.) Is my understanding correct? Does torch treat the tensor differently because the operation happens inside a torch.autograd.Function rather than a module?
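A quick way to check this is a toy Function matching the example above (names are made up): forward doubles the input, but backward returns grad_output * 10, and the custom gradient is what ends up in .grad:

```python
import torch

class Times2Fake10(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2  # this multiplication is NOT tracked by autograd

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 10  # autograd uses this, not d(x*2)/dx

x = torch.ones(3, requires_grad=True)
y = Times2Fake10.apply(x).sum()
y.backward()
print(x.grad)  # tensor([10., 10., 10.]) -- the custom gradient wins
```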
Assuming this understanding is correct, is there a way to combine both behaviors? In particular, suppose my forward function looks something like this:
# inside class MyFunction(torch.autograd.Function)
@staticmethod
def forward(ctx, input1, input2):
    a = some_computation_a(input1)
    b = some_computation_b(input1)
    return a + b
Not all of the steps of some_computation_a are directly differentiable, so I want to supply the gradient for those operations myself, but all of the operations in some_computation_b are, so I’d like torch to track those operations automatically. Is this possible to do?
Does torch treat the tensor differently because the operation happens inside a torch.autograd.Function rather than a module?
No, the Tensors are treated the same way, but everything runs with autograd disabled during the forward.
Is this possible to do?
Of course!
I think the simplest solution (and also a general guideline) is to keep your custom Function as small as possible: it will make the gradients easier for you to derive.
In particular, you can have a custom Function that handles fw/bw for some_computation_a only and then combine it with the rest in a regular python function:
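A minimal sketch of that idea (all names here are placeholders; torch.round stands in for your non-differentiable some_computation_a, with a straight-through gradient as an example hand-written backward):

```python
import torch

class ComputationA(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # stand-in for a non-differentiable op
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # hand-written gradient (here: straight-through estimator)
        return grad_output

def my_function(input1):
    a = ComputationA.apply(input1)  # custom fw/bw
    b = input1 ** 2                 # tracked by autograd as usual
    return a + b

x = torch.tensor([1.3], requires_grad=True)
my_function(x).sum().backward()
print(x.grad)  # 1 (straight-through) + 2*x = tensor([3.6000])
```

The custom Function only ever sees the part whose gradient you must write by hand; everything in the surrounding python function is recorded by autograd normally.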
I would recommend using an nn.Module for this, where you can register your learnable parameters as nn.Parameter so that they are properly detected by other nn constructs.
I would recommend using a custom autograd.Function only if you need to work with something that is non-differentiable, or if you want custom computations (beyond just the gradient) to happen during the backward.
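To illustrate the nn.Module recommendation, here is a hypothetical sketch: the learnable parameter lives on the Module as an nn.Parameter (so .parameters() and optimizers see it), while any custom Function would stay a thin, parameter-free piece called inside forward():

```python
import torch
from torch import nn

class MyLayer(nn.Module):
    def __init__(self, features):
        super().__init__()
        # registered as a parameter, so it is picked up by
        # .parameters(), .named_parameters(), optimizers, .to(), etc.
        self.scale = nn.Parameter(torch.ones(features))

    def forward(self, x):
        # a custom autograd.Function could be applied here if needed
        return x * self.scale

layer = MyLayer(4)
print([name for name, _ in layer.named_parameters()])  # ['scale']
```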